From e45ef46424155812ce5be493fac400d67d1b05e0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 12 May 2020 15:42:26 -0500 Subject: [PATCH] Refactor and add CompositeTextBuilder --- composite_fields.go | 47 +++------------------ composite_type.go | 92 ++++++++++++++++++++++++++++++++++++++++++ composite_type_test.go | 43 ++++++++++++++++++++ 3 files changed, 141 insertions(+), 41 deletions(-) diff --git a/composite_fields.go b/composite_fields.go index b2d9f844..af7bab1e 100644 --- a/composite_fields.go +++ b/composite_fields.go @@ -58,52 +58,17 @@ func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error { // EncodeText encodes composite fields into the text format. Prefer registering a CompositeType to using // CompositeFields to encode directly. func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - buf = append(buf, '(') - - fieldBuf := make([]byte, 0, 32) + b := NewCompositeTextBuilder(ci, buf) for _, f := range cf { - if f != nil { - fieldBuf = fieldBuf[0:0] - if textEncoder, ok := f.(TextEncoder); ok { - var err error - fieldBuf, err = textEncoder.EncodeText(ci, fieldBuf) - if err != nil { - return nil, err - } - if fieldBuf != nil { - buf = append(buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...) - } - } else { - dt, ok := ci.DataTypeForValue(f) - if !ok { - return nil, errors.Errorf("Unknown data type for %#v", f) - } - - err := dt.Value.Set(f) - if err != nil { - return nil, err - } - - if textEncoder, ok := dt.Value.(TextEncoder); ok { - var err error - fieldBuf, err = textEncoder.EncodeText(ci, fieldBuf) - if err != nil { - return nil, err - } - if fieldBuf != nil { - buf = append(buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...) - } - } else { - return nil, errors.Errorf("Cannot encode text format for %v", f) - } - } + if textEncoder, ok := f.(TextEncoder); ok { + b.AppendEncoder(textEncoder) + } else { + b.AppendValue(f) } - buf = append(buf, ',') } - buf[len(buf)-1] = ')' - return buf, nil + return b.Finish() } // EncodeBinary encodes composite fields into the binary format. Unlike CompositeType the schema of the destination is diff --git a/composite_type.go b/composite_type.go index f01e8e64..6baa639a 100644 --- a/composite_type.go +++ b/composite_type.go @@ -177,6 +177,27 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { return nil } +func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error { + if buf == nil { + dst.status = Null + return nil + } + + scanner := NewCompositeTextScanner(ci, buf) + + for _, f := range dst.fields { + scanner.ScanDecoder(f) + } + + if scanner.Err() != nil { + return scanner.Err() + } + + dst.status = Present + + return nil +} + type CompositeBinaryScanner struct { ci *ConnInfo rp int @@ -474,6 +495,77 @@ func (b *CompositeBinaryBuilder) Finish() ([]byte, error) { return b.buf, nil } +type CompositeTextBuilder struct { + ci *ConnInfo + buf []byte + startIdx int + fieldCount uint32 + err error + fieldBuf [32]byte +} + +func NewCompositeTextBuilder(ci *ConnInfo, buf []byte) *CompositeTextBuilder { + buf = append(buf, '(') // allocate room for number of fields + return &CompositeTextBuilder{ci: ci, buf: buf} +} + +func (b *CompositeTextBuilder) AppendValue(field interface{}) { + if b.err != nil { + return + } + + if field == nil { + b.buf = append(b.buf, ',') + return + } + + dt, ok := b.ci.DataTypeForValue(field) + if !ok { + b.err = errors.Errorf("unknown data type for field: %v", field) + return + } + + err := dt.Value.Set(field) + if err != nil { + b.err = err + return + } + + textEncoder, ok := dt.Value.(TextEncoder) + if !ok { + b.err = errors.Errorf("unable to encode text for value: %v", field) + return + } + + b.AppendEncoder(textEncoder) +} + +func (b *CompositeTextBuilder) AppendEncoder(field TextEncoder) { + if b.err != nil { + return + } + + fieldBuf, err := field.EncodeText(b.ci, b.fieldBuf[0:0]) + if err != nil { + b.err = err + return + } + if fieldBuf != nil { + b.buf = append(b.buf, QuoteCompositeFieldIfNeeded(string(fieldBuf))...) + } + + b.buf = append(b.buf, ',') +} + +func (b *CompositeTextBuilder) Finish() ([]byte, error) { + if b.err != nil { + return nil, b.err + } + + b.buf[len(b.buf)-1] = ')' + return b.buf, nil +} + var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) func quoteCompositeField(src string) string { diff --git a/composite_type_test.go b/composite_type_test.go index 92ecc849..17d34251 100644 --- a/composite_type_test.go +++ b/composite_type_test.go @@ -7,8 +7,10 @@ import ( "testing" "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" pgx "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCompositeTypeSetAndGet(t *testing.T) { @@ -130,6 +132,47 @@ func TestCompositeTypeAssignTo(t *testing.T) { } } +func TestCompositeTypeTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists ct_test; + +create type ct_test as ( + a text, + b int4 +);`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type ct_test") + + var oid uint32 + err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid) + require.NoError(t, err) + + defer conn.Exec(context.Background(), "drop type ct_test") + + ct := pgtype.NewCompositeType("ct_test", &pgtype.Text{}, &pgtype.Int4{}) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: "ct_test", OID: oid}) + + // Use simple protocol to force text or binary encoding + simpleProtocols := []bool{true, false} + + var a string + var b int32 + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{"hi", int32(42)}, + ).Scan( + []interface{}{&a, &b}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 42, b, "Simple Protocol: %v", simpleProtocol) + } + } +} + func Example_composite() { conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil {