diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..a32b4d68 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "go.inferGopath": false, + "go.testEnvVars": { + "PGX_TEST_DATABASE": "user=postgres database=pgx_test host=127.0.0.1" + }, +} \ No newline at end of file diff --git a/binary/record.go b/binary/record.go new file mode 100644 index 00000000..72b688a8 --- /dev/null +++ b/binary/record.go @@ -0,0 +1,78 @@ +package binary + +import ( + "encoding/binary" + + "github.com/jackc/pgio" + errors "golang.org/x/xerrors" +) + +type RecordFieldIter struct { + rp int + src []byte +} + +// NewRecordFieldIterator creates iterator over binary representation +// of record, aka ROW(), aka Composite +func NewRecordFieldIterator(src []byte) (RecordFieldIter, int, error) { + rp := 0 + if len(src[rp:]) < 4 { + return RecordFieldIter{}, 0, errors.Errorf("Record incomplete %v", src) + } + + fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + return RecordFieldIter{ + rp: rp, + src: src, + }, fieldCount, nil +} + +// Next returns next field decoded from record. eof is returned if no +// more fields left to decode. +func (fi *RecordFieldIter) Next() (fieldOID uint32, buf []byte, eof bool, err error) { + if fi.rp == len(fi.src) { + eof = true + return + } + + if len(fi.src[fi.rp:]) < 8 { + err = errors.Errorf("Record incomplete %v", fi.src) + return + } + fieldOID = binary.BigEndian.Uint32(fi.src[fi.rp:]) + fi.rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(fi.src[fi.rp:]))) + fi.rp += 4 + + if fieldLen >= 0 { + if len(fi.src[fi.rp:]) < fieldLen { + err = errors.Errorf("Record incomplete rp=%d src=%v", fi.rp, fi.src) + return + } + buf = fi.src[fi.rp : fi.rp+fieldLen] + fi.rp += fieldLen + } + + return +} + +// RecordStart adds record header to the buf +func RecordStart(buf []byte, fieldCount int) []byte { + return pgio.AppendUint32(buf, uint32(fieldCount)) +} + +// RecordAdd adds record field to the buf +func RecordAdd(buf []byte, oid uint32, fieldBytes []byte) []byte { + buf = pgio.AppendUint32(buf, oid) + buf = pgio.AppendUint32(buf, uint32(len(fieldBytes))) + buf = append(buf, fieldBytes...) + return buf +} + +// RecordAddNull adds null value as a field to the buf +func RecordAddNull(buf []byte, oid uint32) []byte { + return pgio.AppendInt32(buf, int32(-1)) +} diff --git a/composite.go b/composite.go new file mode 100644 index 00000000..61034262 --- /dev/null +++ b/composite.go @@ -0,0 +1,153 @@ +package pgtype + +import ( + "github.com/jackc/pgtype/binary" + errors "golang.org/x/xerrors" +) + +type Composite struct { + fields []Value + Status Status +} + +// NewComposite creates a Composite object, which acts as a "schema" for +// SQL composite values. +// To pass Composite as SQL parameter first set it's fields, either by +// passing initialized Value{} instances to NewComposite or by calling +// SetFields method +// To read composite fields back pass result of Scan() method +// to query Scan function. +func NewComposite(fields ...Value) *Composite { + return &Composite{fields, Present} +} + +func (src Composite) Get() interface{} { + switch src.Status { + case Present: + return src + case Null: + return nil + default: + return src.Status + } +} + +// Set is called internally when passing query arguments. +func (dst *Composite) Set(src interface{}) error { + if src == nil { + *dst = Composite{Status: Null} + return nil + } + + switch value := src.(type) { + case []Value: + if len(value) != len(dst.fields) { + return errors.Errorf("Number of fields don't match. Composite has %d fields", len(dst.fields)) + } + for i, v := range value { + if err := dst.fields[i].Set(v); err != nil { + return err + } + } + dst.Status = Present + default: + return errors.Errorf("Can not convert %v to Composite", src) + } + + return nil +} + +// AssignTo should never be called on composite value directly +func (src Composite) AssignTo(dst interface{}) error { + return errors.New("Pass Composite.Scan() to deconstruct composite") +} + +func (src Composite) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + return EncodeRow(ci, buf, src.fields...) +} + +// DecodeBinary implements BinaryDecoder interface. +// Opposite to Record, fields in a composite act as a "schema" +// and decoding fails if SQL value can't be assigned due to +// type mismatch +func (dst *Composite) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { + if buf == nil { + dst.Status = Null + return nil + } + + fieldIter, fieldCount, err := binary.NewRecordFieldIterator(buf) + if err != nil { + return err + } else if len(dst.fields) != fieldCount { + return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(dst.fields), fieldCount) + } + + _, fieldBytes, eof, err := fieldIter.Next() + + for i := 0; !eof; i++ { + if err != nil { + return err + } + + binaryDecoder, ok := dst.fields[i].(BinaryDecoder) + if !ok { + return errors.New("Composite field doesn't support binary protocol") + } + + if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + return err + } + + _, fieldBytes, eof, err = fieldIter.Next() + } + dst.Status = Present + + return nil +} + +// Scan is a helper function to perform "nested" scan of +// a composite value when scanning a query result row. +// isNull is set if scanned value is NULL +// Rest of arguments are set in the order of fields in the composite +// +// Use of Scan method doesn't modify original composite +func (src Composite) Scan(isNull *bool, dst ...interface{}) BinaryDecoderFunc { + return func(ci *ConnInfo, buf []byte) error { + if err := src.DecodeBinary(ci, buf); err != nil { + return err + } + + if src.Status == Null { + *isNull = true + return nil + } + + for i, f := range src.fields { + if err := f.AssignTo(dst[i]); err != nil { + return err + } + } + return nil + } +} + +// SetFields sets Composite's fields to corresponding values +func (dst *Composite) SetFields(values ...interface{}) error { + if len(values) != len(dst.fields) { + return errors.Errorf("Number of fields don't match. Composite has %d fields", len(dst.fields)) + } + for i, v := range values { + if err := dst.fields[i].Set(v); err != nil { + return err + } + } + dst.Status = Present + return nil +} diff --git a/composite_bench_test.go b/composite_bench_test.go new file mode 100644 index 00000000..429ce9b3 --- /dev/null +++ b/composite_bench_test.go @@ -0,0 +1,196 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/binary" + errors "golang.org/x/xerrors" +) + +type MyCompositeRaw struct { + A int32 + B *string +} + +func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) { + a := pgtype.Int4{src.A, pgtype.Present} + + fieldBytes := make([]byte, 0, 64) + fieldBytes, _ = a.EncodeBinary(ci, fieldBytes[:0]) + + newBuf = binary.RecordStart(buf, 2) + newBuf = binary.RecordAdd(newBuf, pgtype.Int4OID, fieldBytes) + + if src.B != nil { + fieldBytes, _ = pgtype.Text{*src.B, pgtype.Present}.EncodeBinary(ci, fieldBytes[:0]) + newBuf = binary.RecordAdd(newBuf, pgtype.TextOID, fieldBytes) + } else { + newBuf = binary.RecordAddNull(newBuf, pgtype.TextOID) + } + return +} + +func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + a := pgtype.Int4{} + b := pgtype.Text{} + + fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src) + if err != nil { + return err + } + + if 2 != fieldCount { + return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=2", fieldCount) + } + + _, fieldBytes, eof, err := fieldIter.Next() + if eof || err != nil { + return errors.New("Bad record") + } + if err = a.DecodeBinary(ci, fieldBytes); err != nil { + return err + } + + _, fieldBytes, eof, err = fieldIter.Next() + if eof || err != nil { + return errors.New("Bad record") + } + if err = b.DecodeBinary(ci, fieldBytes); err != nil { + return err + } + + dst.A = a.Int + if b.Status == pgtype.Present { + dst.B = &b.String + } else { + dst.B = nil + } + + return nil +} + +var x []byte + +func BenchmarkBinaryEncodingManual(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + v := MyCompositeRaw{4, ptrS("ABCDEFG")} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + buf, _ = v.EncodeBinary(ci, buf[:0]) + } + x = buf +} + +func BenchmarkBinaryEncodingHelper(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + v := MyType{4, ptrS("ABCDEFG")} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + buf, _ = v.EncodeBinary(ci, buf[:0]) + } + x = buf +} + +func BenchmarkBinaryEncodingComposite(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + f1 := 2 + f2 := ptrS("bar") + c := pgtype.NewComposite(&pgtype.Int4{}, &pgtype.Text{}) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + c.SetFields(f1, f2) + buf, _ = c.EncodeBinary(ci, buf[:0]) + } + x = buf +} + +func BenchmarkBinaryEncodingJSON(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + v := MyCompositeRaw{4, ptrS("ABCDEFG")} + j := pgtype.JSON{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + j.Set(v) + buf, _ = j.EncodeBinary(ci, buf[:0]) + } + x = buf +} + +var dstRaw MyCompositeRaw + +func BenchmarkBinaryDecodingManual(b *testing.B) { + ci := pgtype.NewConnInfo() + buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) + dst := MyCompositeRaw{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := dst.DecodeBinary(ci, buf) + E(err) + } + dstRaw = dst +} + +var dstMyType MyType + +func BenchmarkBinaryDecodingHelpers(b *testing.B) { + ci := pgtype.NewConnInfo() + buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) + dst := MyType{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := dst.DecodeBinary(ci, buf) + E(err) + } + dstMyType = dst +} + +var gf1 int +var gf2 *string + +func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { + ci := pgtype.NewConnInfo() + buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) + var isNull bool + var f1 int + var f2 *string + + c := pgtype.NewComposite(&pgtype.Int4{}, &pgtype.Text{}) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := c.Scan(&isNull, &f1, &f2).DecodeBinary(ci, buf) + E(err) + } + gf1 = f1 + gf2 = f2 +} + +func BenchmarkBinaryDecodingJSON(b *testing.B) { + ci := pgtype.NewConnInfo() + j := pgtype.JSON{} + j.Set(MyCompositeRaw{4, ptrS("ABCDEFG")}) + buf, _ := j.EncodeBinary(ci, nil) + + j = pgtype.JSON{} + dst := MyCompositeRaw{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := j.DecodeBinary(ci, buf) + E(err) + err = j.AssignTo(&dst) + E(err) + } + dstRaw = dst +} diff --git a/composite_test.go b/composite_test.go new file mode 100644 index 00000000..ac0eb4d0 --- /dev/null +++ b/composite_test.go @@ -0,0 +1,57 @@ +package pgtype_test + +import ( + "context" + "fmt" + "os" + + "github.com/jackc/pgtype" + pgx "github.com/jackc/pgx/v4" +) + +//ExampleComposite demonstrates use of Row() function to pass and receive +// back composite types without creating boilderplate custom types. +func Example_composite() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + E(err) + + defer conn.Close(context.Background()) + _, err = conn.Exec(context.Background(), `drop type if exists mytype; + +create type mytype as ( + a int4, + b text +);`) + E(err) + defer conn.Exec(context.Background(), "drop type mytype") + + qrf := pgx.QueryResultFormats{pgx.BinaryFormatCode} + + var isNull bool + var a int + var b *string + + c := pgtype.NewComposite(&pgtype.Int4{}, &pgtype.Text{}) + c.SetFields(2, "bar") + + err = conn.QueryRow(context.Background(), "select $1::mytype", qrf, c). + Scan(c.Scan(&isNull, &a, &b)) + E(err) + + fmt.Printf("First: isNull=%v a=%d b=%s\n", isNull, a, *b) + + err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype", qrf).Scan(c.Scan(&isNull, &a, &b)) + E(err) + + fmt.Printf("Second: isNull=%v a=%d b=%v\n", isNull, a, b) + + err = conn.QueryRow(context.Background(), "select NULL::mytype", qrf).Scan(c.Scan(&isNull, &a, &b)) + E(err) + + fmt.Printf("Third: isNull=%v\n", isNull) + + // Output: + // First: isNull=false a=2 b=bar + // Second: isNull=false a=1 b= + // Third: isNull=true +} diff --git a/convert.go b/convert.go index cc5c10ab..8008d677 100644 --- a/convert.go +++ b/convert.go @@ -5,6 +5,7 @@ import ( "reflect" "time" + "github.com/jackc/pgtype/binary" errors "golang.org/x/xerrors" ) @@ -433,6 +434,68 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { return nil, false } +// ScanRowValue decodes ROW()'s and composite type +// from src argument using provided decoders. Decoders should match +// order and count of fields of record being decoded. +// +// In practice you can pass pgtype.Value types as decoders, as +// most of them implement BinaryDecoder interface. +// +// ScanRowValue takes ownership of src, caller MUST not use it after call +func ScanRowValue(ci *ConnInfo, src []byte, dst ...BinaryDecoder) error { + fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src) + if err != nil { + return err + } + + if len(dst) != fieldCount { + return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=%d", fieldCount, len(dst)) + } + + _, fieldBytes, eof, err := fieldIter.Next() + for i := 0; !eof; i++ { + if err != nil { + return err + } + + if err = dst[i].DecodeBinary(ci, fieldBytes); err != nil { + return err + } + + _, fieldBytes, eof, err = fieldIter.Next() + } + + return nil +} + +// EncodeRow builds a binary representation of row values (row(), composite types) +func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) { + fieldBytes := make([]byte, 0, 128) + + newBuf = binary.RecordStart(buf, len(fields)) + for _, f := range fields { + dt, ok := ci.DataTypeForValue(f) + if !ok { + return nil, errors.Errorf("Unknown OID for %s", f) + } + if f.Get() != nil { + binaryEncoder, ok := f.(BinaryEncoder) + if !ok { + return nil, errors.Errorf("record field doesn't implement binary encoding: %s", reflect.TypeOf(f).Name()) + } + fieldBytes, err = binaryEncoder.EncodeBinary(ci, fieldBytes[:0]) + if err != nil { + return nil, err + } + newBuf = binary.RecordAdd(newBuf, dt.OID, fieldBytes) + } else { + newBuf = binary.RecordAddNull(newBuf, dt.OID) + } + + } + return +} + func init() { kindTypes = map[reflect.Kind]reflect.Type{ reflect.Bool: reflect.TypeOf(false), diff --git a/custom_composite_test.go b/custom_composite_test.go new file mode 100644 index 00000000..61ea91c5 --- /dev/null +++ b/custom_composite_test.go @@ -0,0 +1,101 @@ +package pgtype_test + +import ( + "context" + "fmt" + "os" + + "github.com/jackc/pgtype" + pgx "github.com/jackc/pgx/v4" + errors "golang.org/x/xerrors" +) + +type MyType struct { + a int32 // NULL will cause decoding error + b *string // there can be NULL in this position in SQL +} + +func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs") + } + + a := pgtype.Int4{} + b := pgtype.Text{} + + if err := pgtype.ScanRowValue(ci, src, &a, &b); err != nil { + return err + } + + // type compatibility is checked by AssignTo + // only lossless assignments will succeed + if err := a.AssignTo(&dst.a); err != nil { + return err + } + + // AssignTo also deals with null value handling + if err := b.AssignTo(&dst.b); err != nil { + return err + } + + return nil +} + +func (src MyType) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) { + a := pgtype.Int4{src.a, pgtype.Present} + var b pgtype.Text + if src.b != nil { + b = pgtype.Text{*src.b, pgtype.Present} + } else { + b = pgtype.Text{Status: pgtype.Null} + } + + return pgtype.EncodeRow(ci, buf, &a, &b) +} + +func ptrS(s string) *string { + return &s +} + +func E(err error) { + if err != nil { + panic(err) + } +} + +// ExampleCustomCompositeTypes demonstrates how support for custom types mappable to SQL +// composites can be added. +func Example_customCompositeTypes() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + E(err) + + defer conn.Close(context.Background()) + _, err = conn.Exec(context.Background(), `drop type if exists mytype; + +create type mytype as ( + a int4, + b text +);`) + E(err) + defer conn.Exec(context.Background(), "drop type mytype") + + var result *MyType + + // Demonstrates both passing and reading back composite values + err = conn.QueryRow(context.Background(), "select $1::mytype", + pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}). + Scan(&result) + E(err) + + fmt.Printf("First row: a=%d b=%s\n", result.a, *result.b) + + // Because we scan into &*MyType, NULLs are handled generically by assigning nil to result + err = conn.QueryRow(context.Background(), "select NULL::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result) + E(err) + + fmt.Printf("Second row: %v\n", result) + + // Output: + // First row: a=1 b=foo + // Second row: +} diff --git a/pgtype.go b/pgtype.go index c002150c..d0d4885c 100644 --- a/pgtype.go +++ b/pgtype.go @@ -174,6 +174,24 @@ type TextEncoder interface { EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) } +//The BinaryDecoderFunc type is an adapter to allow the use of ordinary functions as BinaryDecoder types. +// If f is a function with the appropriate signature, BinaryDecoderFunc(f) is a BinaryDecoder that calls f. +type BinaryDecoderFunc func(ci *ConnInfo, src []byte) error + +// DecodeBinary calls f(ci, src) +func (f BinaryDecoderFunc) DecodeBinary(ci *ConnInfo, src []byte) error { + return f(ci, src) +} + +//The BinaryEncoderFunc type is an adapter to allow the use of ordinary functions as BinaryDecoder types. +// If f is a function with the appropriate signature, BinaryEncoderFunc(f) is a BinaryDecoder that calls f. +type BinaryEncoderFunc func(ci *ConnInfo, buf []byte) ([]byte, error) + +// EncodeBinary calls f(ci, buf) +func (f BinaryEncoderFunc) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { + return f(ci, buf) +} + var errUndefined = errors.New("cannot encode status undefined") var errBadStatus = errors.New("invalid status") diff --git a/record.go b/record.go index 5c9d7a02..4e39f92a 100644 --- a/record.go +++ b/record.go @@ -1,9 +1,10 @@ package pgtype import ( - "encoding/binary" "reflect" + "github.com/jackc/pgtype/binary" + errors "golang.org/x/xerrors" ) @@ -78,57 +79,54 @@ func (src *Record) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDecoder, error) { + var binaryDecoder BinaryDecoder + + if dt, ok := ci.DataTypeForOID(fieldOID); ok { + binaryDecoder, _ = dt.Value.(BinaryDecoder) + } else { + return nil, errors.Errorf("unknown oid while decoding record: %v", fieldOID) + } + + if binaryDecoder == nil { + return nil, errors.Errorf("no binary decoder registered for: %v", fieldOID) + } + + // Duplicate struct to scan into + binaryDecoder = reflect.New(reflect.ValueOf(binaryDecoder).Elem().Type()).Interface().(BinaryDecoder) + *v = binaryDecoder.(Value) + return binaryDecoder, nil +} + func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Record{Status: Null} return nil } - rp := 0 - - if len(src[rp:]) < 4 { - return errors.Errorf("Record incomplete %v", src) + fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src) + if err != nil { + return err } - fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 fields := make([]Value, fieldCount) + fieldOID, fieldBytes, eof, err := fieldIter.Next() - for i := 0; i < fieldCount; i++ { - if len(src[rp:]) < 8 { - return errors.Errorf("Record incomplete %v", src) - } - fieldOID := binary.BigEndian.Uint32(src[rp:]) - rp += 4 - - fieldLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - - var binaryDecoder BinaryDecoder - if dt, ok := ci.DataTypeForOID(fieldOID); ok { - binaryDecoder, _ = dt.Value.(BinaryDecoder) - } - if binaryDecoder == nil { - return errors.Errorf("unknown oid while decoding record: %v", fieldOID) - } - - var fieldBytes []byte - if fieldLen >= 0 { - if len(src[rp:]) < fieldLen { - return errors.Errorf("Record incomplete %v", src) - } - fieldBytes = src[rp : rp+fieldLen] - rp += fieldLen - } - - // Duplicate struct to scan into - binaryDecoder = reflect.New(reflect.ValueOf(binaryDecoder).Elem().Type()).Interface().(BinaryDecoder) - - if err := binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + for i := 0; !eof; i++ { + if err != nil { return err } - fields[i] = binaryDecoder.(Value) + binaryDecoder, err := prepareNewBinaryDecoder(ci, fieldOID, &fields[i]) + if err != nil { + return err + } + + if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + return err + } + + fieldOID, fieldBytes, eof, err = fieldIter.Next() } *dst = Record{Fields: fields, Status: Present} diff --git a/record_test.go b/record_test.go index 71a2f702..9516612e 100644 --- a/record_test.go +++ b/record_test.go @@ -11,94 +11,145 @@ import ( "github.com/jackc/pgx/v4" ) +var recordTests = []struct { + sql string + expected pgtype.Record +}{ + { + sql: `select row()`, + expected: pgtype.Record{ + Fields: []pgtype.Value{}, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row(100.0::float4, 1.09::float4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Float4{Float: 100, Status: pgtype.Present}, + &pgtype.Float4{Float: 1.09, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row(null)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Unknown{Status: pgtype.Null}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select null::record`, + expected: pgtype.Record{ + Status: pgtype.Null, + }, + }, +} + +// row values are binary compatible with records, so we test our helper +// routines here +func TestScanRowValue(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + for i := 0; i < len(recordTests); i++ { + tt := recordTests[i] + psName := fmt.Sprintf("test%d", i) + _, err := conn.Prepare(context.Background(), psName, tt.sql) + if err != nil { + t.Fatal(err) + } + t.Run(tt.sql, func(t *testing.T) { + desc := []pgtype.BinaryDecoder{} + for _, f := range tt.expected.Fields { + desc = append(desc, f.(pgtype.BinaryDecoder)) + } + + var raw pgtype.GenericBinary + + if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&raw); err != nil { + t.Error(err) + return + } + + if raw.Status == pgtype.Null { + // ScanRowValue deals with complete rows only, NULL values (but NOT null fields) + // should be handled by the calling code + return + } + + if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err != nil { + t.Error(err) + } + + // borrow fields from a neighbor test, this makes scan always fail + desc = desc[:0] + for _, f := range recordTests[(i+1)%len(recordTests)].expected.Fields { + desc = append(desc, f.(pgtype.BinaryDecoder)) + } + if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err == nil { + t.Error("Matching scan didn't fail, despite fields not mathching query result") + } + }) + } +} + func TestRecordTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) - tests := []struct { - sql string - expected pgtype.Record - }{ - { - sql: `select row()`, - expected: pgtype.Record{ - Fields: []pgtype.Value{}, - Status: pgtype.Present, - }, - }, - { - sql: `select row('foo'::text, 42::int4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select row(100.0::float4, 1.09::float4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Float4{Float: 100, Status: pgtype.Present}, - &pgtype.Float4{Float: 1.09, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Text{String: "foo", Status: pgtype.Present}, - &pgtype.Int4Array{ - Elements: []pgtype.Int4{ - {Int: 1, Status: pgtype.Present}, - {Int: 2, Status: pgtype.Present}, - {Status: pgtype.Null}, - {Int: 4, Status: pgtype.Present}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, - Status: pgtype.Present, - }, - &pgtype.Int4{Int: 42, Status: pgtype.Present}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select row(null)`, - expected: pgtype.Record{ - Fields: []pgtype.Value{ - &pgtype.Unknown{Status: pgtype.Null}, - }, - Status: pgtype.Present, - }, - }, - { - sql: `select null::record`, - expected: pgtype.Record{ - Status: pgtype.Null, - }, - }, - } - - for i, tt := range tests { + for i, tt := range recordTests { psName := fmt.Sprintf("test%d", i) _, err := conn.Prepare(context.Background(), psName, tt.sql) if err != nil { t.Fatal(err) } - var result pgtype.Record - if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { - t.Errorf("%d: %v", i, err) - continue - } + t.Run(tt.sql, func(t *testing.T) { + var result pgtype.Record + if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { + t.Errorf("%v", err) + return + } + + if !reflect.DeepEqual(tt.expected, result) { + t.Errorf("expected %#v, got %#v", tt.expected, result) + } + }) - if !reflect.DeepEqual(tt.expected, result) { - t.Errorf("%d: expected %#v, got %#v", i, tt.expected, result) - } } }