diff --git a/composite.go b/composite.go index 4e6b68ca..59549736 100644 --- a/composite.go +++ b/composite.go @@ -233,6 +233,96 @@ func (cfs *CompositeBinaryScanner) Err() error { return cfs.err } +type CompositeTextScanner struct { + rp int + src []byte + + fieldBytes []byte + err error +} + +// NewCompositeTextScanner a scanner over a text encoded composite balue. +func NewCompositeTextScanner(src []byte) (CompositeTextScanner, error) { + if len(src) < 2 { + return CompositeTextScanner{}, errors.Errorf("Record incomplete %v", src) + } + + if src[0] != '(' { + return CompositeTextScanner{}, errors.Errorf("composite text format must start with '('") + } + + if src[len(src)-1] != ')' { + return CompositeTextScanner{}, errors.Errorf("composite text format must end with ')'") + } + + return CompositeTextScanner{ + rp: 1, + src: src, + }, nil +} + +// Scan advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Scan returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeTextScanner) Scan() bool { + if cfs.err != nil { + return false + } + + if cfs.rp == len(cfs.src) { + return false + } + + switch cfs.src[cfs.rp] { + case ',', ')': // null + cfs.rp++ + cfs.fieldBytes = nil + return true + case '"': // quoted value + cfs.rp++ + cfs.fieldBytes = make([]byte, 0, 16) + for { + ch := cfs.src[cfs.rp] + + if ch == '"' { + cfs.rp++ + if cfs.src[cfs.rp] == '"' { + cfs.fieldBytes = append(cfs.fieldBytes, '"') + cfs.rp++ + } else { + break + } + } else { + cfs.fieldBytes = append(cfs.fieldBytes, ch) + cfs.rp++ + } + } + cfs.rp++ + return true + default: // unquoted value + start := cfs.rp + for { + ch := cfs.src[cfs.rp] + if ch == ',' || ch == ')' { + break + } + cfs.rp++ + } + cfs.fieldBytes = cfs.src[start:cfs.rp] + cfs.rp++ + return true + } +} + +// Bytes returns the bytes of the field most recently read by Scan(). +func (cfs *CompositeTextScanner) Bytes() []byte { + return cfs.fieldBytes +} + +// Err returns any error encountered by the scanner. +func (cfs *CompositeTextScanner) Err() error { + return cfs.err +} + // RecordStart adds record header to the buf func RecordStart(buf []byte, fieldCount int) []byte { return pgio.AppendUint32(buf, uint32(fieldCount)) diff --git a/composite_fields.go b/composite_fields.go new file mode 100644 index 00000000..64a17b55 --- /dev/null +++ b/composite_fields.go @@ -0,0 +1,76 @@ +package pgtype + +import ( + errors "golang.org/x/xerrors" +) + +// CompositeFields scans the fields of a composite type into the elements of the CompositeFields value. To scan a +// nullable value use a *CompositeFields. It will be set to nil in case of null. +type CompositeFields []interface{} + +func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error { + if len(cf) == 0 { + return errors.Errorf("cannot decode into empty CompositeFields") + } + + if src == nil { + return errors.Errorf("cannot decode unexpected null into CompositeFields") + } + + scanner, err := NewCompositeBinaryScanner(src) + if err != nil { + return err + } + if len(cf) != scanner.FieldCount() { + return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), scanner.FieldCount()) + } + + for i := 0; scanner.Scan(); i++ { + err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), cf[i]) + if err != nil { + return err + } + } + + if scanner.Err() != nil { + return scanner.Err() + } + + return nil +} + +func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error { + if len(cf) == 0 { + return errors.Errorf("cannot decode into empty CompositeFields") + } + + if src == nil { + return errors.Errorf("cannot decode unexpected null into CompositeFields") + } + + scanner, err := NewCompositeTextScanner(src) + if err != nil { + return err + } + + fieldCount := 0 + + for i := 0; scanner.Scan(); i++ { + err := ci.Scan(0, TextFormatCode, scanner.Bytes(), cf[i]) + if err != nil { + return err + } + + fieldCount += 1 + } + + if scanner.Err() != nil { + return scanner.Err() + } + + if len(cf) != fieldCount { + return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(cf), fieldCount) + } + + return nil +} diff --git a/composite_fields_test.go b/composite_fields_test.go new file mode 100644 index 00000000..d53e48ec --- /dev/null +++ b/composite_fields_test.go @@ -0,0 +1,126 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" +) + +func TestCompositeFieldsDecode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + formats := []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} + + // Assorted values + { + var a int32 + var b string + var c float64 + + for _, format := range formats { + err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b, &c}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.EqualValuesf(t, 1, a, "Format: %v", format) + assert.EqualValuesf(t, "hi", b, "Format: %v", format) + assert.EqualValuesf(t, 2.1, c, "Format: %v", format) + } + } + + // nulls, string "null", and empty string fields + { + var a pgtype.Text + var b string + var c pgtype.Text + var d string + var e pgtype.Text + + for _, format := range formats { + err := conn.QueryRow(context.Background(), "select row(null,'null',null,'',null)", pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.Nilf(t, a.Get(), "Format: %v", format) + assert.EqualValuesf(t, "null", b, "Format: %v", format) + assert.Nilf(t, c.Get(), "Format: %v", format) + assert.EqualValuesf(t, "", d, "Format: %v", format) + assert.Nilf(t, e.Get(), "Format: %v", format) + } + } + + // null record + { + var a pgtype.Text + var b string + cf := pgtype.CompositeFields{&a, &b} + + for _, format := range formats { + // Cannot scan nil into + err := conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan( + cf, + ) + if assert.Errorf(t, err, "Format: %v", format) { + continue + } + assert.NotNilf(t, cf, "Format: %v", format) + + // But can scan nil into *pgtype.CompositeFields + err = conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan( + &cf, + ) + if assert.Errorf(t, err, "Format: %v", format) { + continue + } + assert.Nilf(t, cf, "Format: %v", format) + } + } + + // quotes and special characters + { + var a, b, c, d string + + for _, format := range formats { + err := conn.QueryRow(context.Background(), `select row('"', 'foo bar', 'foo''bar', 'baz)bar')`, pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b, &c, &d}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.Equalf(t, `"`, a, "Format: %v", format) + assert.Equalf(t, `foo bar`, b, "Format: %v", format) + assert.Equalf(t, `foo'bar`, c, "Format: %v", format) + assert.Equalf(t, `baz)bar`, d, "Format: %v", format) + } + } + + // arrays + { + var a []string + var b []int64 + + for _, format := range formats { + err := conn.QueryRow(context.Background(), `select row(array['foo', 'bar', 'baz'], array[1,2,3])`, pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.EqualValuesf(t, []string{"foo", "bar", "baz"}, a, "Format: %v", format) + assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format) + } + } +} diff --git a/convert.go b/convert.go index 47227fd5..6e70e82e 100644 --- a/convert.go +++ b/convert.go @@ -433,38 +433,6 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { return nil, false } -// ScanRowValue decodes ROW()'s and composite type -// from src argument using provided decoders. Decoders should match -// order and count of fields of record being decoded. -// -// In practice you can pass pgtype.Value types as decoders, as -// most of them implement BinaryDecoder interface. -// -// ScanRowValue takes ownership of src, caller MUST not use it after call -func ScanRowValue(ci *ConnInfo, src []byte, dst ...interface{}) error { - scanner, err := NewCompositeBinaryScanner(src) - if err != nil { - return err - } - - if len(dst) != scanner.FieldCount() { - return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=%d", scanner.FieldCount(), len(dst)) - } - - for i := 0; scanner.Scan(); i++ { - err := ci.Scan(scanner.OID(), BinaryFormatCode, scanner.Bytes(), dst[i]) - if err != nil { - return err - } - } - - if scanner.Err() != nil { - return scanner.Err() - } - - return nil -} - // EncodeRow builds a binary representation of row values (row(), composite types) func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) { fieldBytes := make([]byte, 0, 128) diff --git a/custom_composite_test.go b/custom_composite_test.go index f6f37ec7..a93a8ad0 100644 --- a/custom_composite_test.go +++ b/custom_composite_test.go @@ -20,7 +20,7 @@ func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs") } - if err := pgtype.ScanRowValue(ci, src, &dst.a, &dst.b); err != nil { + if err := (pgtype.CompositeFields{&dst.a, &dst.b}).DecodeBinary(ci, src); err != nil { return err } diff --git a/record_test.go b/record_test.go index 3794fcd7..240812a6 100644 --- a/record_test.go +++ b/record_test.go @@ -79,54 +79,6 @@ var recordTests = []struct { }, } -// row values are binary compatible with records, so we test our helper -// routines here -func TestScanRowValue(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - for i := 0; i < len(recordTests); i++ { - tt := recordTests[i] - psName := fmt.Sprintf("test%d", i) - _, err := conn.Prepare(context.Background(), psName, tt.sql) - if err != nil { - t.Fatal(err) - } - t.Run(tt.sql, func(t *testing.T) { - desc := []interface{}{} - for _, f := range tt.expected.Fields { - desc = append(desc, f.(pgtype.BinaryDecoder)) - } - - var raw pgtype.GenericBinary - - if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&raw); err != nil { - t.Error(err) - return - } - - if raw.Status == pgtype.Null { - // ScanRowValue deals with complete rows only, NULL values (but NOT null fields) - // should be handled by the calling code - return - } - - if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err != nil { - t.Error(err) - } - - // borrow fields from a neighbor test, this makes scan always fail - desc = desc[:0] - for _, f := range recordTests[(i+1)%len(recordTests)].expected.Fields { - desc = append(desc, f.(pgtype.BinaryDecoder)) - } - if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err == nil { - t.Error("Matching scan didn't fail, despite fields not mathching query result") - } - }) - } -} - func TestRecordTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn)