2
0

Merge branch 'record-expect' of git://github.com/redbaron/pgtype into redbaron-record-expect

This commit is contained in:
Jack Christensen
2020-05-06 14:43:10 -05:00
10 changed files with 837 additions and 116 deletions
+6
View File
@@ -0,0 +1,6 @@
{
"go.inferGopath": false,
"go.testEnvVars": {
"PGX_TEST_DATABASE": "user=postgres database=pgx_test host=127.0.0.1"
},
}
+78
View File
@@ -0,0 +1,78 @@
package binary
import (
"encoding/binary"
"github.com/jackc/pgio"
errors "golang.org/x/xerrors"
)
type RecordFieldIter struct {
rp int
src []byte
}
// NewRecordFieldIterator creates iterator over binary representation
// of record, aka ROW(), aka Composite
func NewRecordFieldIterator(src []byte) (RecordFieldIter, int, error) {
rp := 0
if len(src[rp:]) < 4 {
return RecordFieldIter{}, 0, errors.Errorf("Record incomplete %v", src)
}
fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
return RecordFieldIter{
rp: rp,
src: src,
}, fieldCount, nil
}
// Next returns next field decoded from record. eof is returned if no
// more fields left to decode.
func (fi *RecordFieldIter) Next() (fieldOID uint32, buf []byte, eof bool, err error) {
if fi.rp == len(fi.src) {
eof = true
return
}
if len(fi.src[fi.rp:]) < 8 {
err = errors.Errorf("Record incomplete %v", fi.src)
return
}
fieldOID = binary.BigEndian.Uint32(fi.src[fi.rp:])
fi.rp += 4
fieldLen := int(int32(binary.BigEndian.Uint32(fi.src[fi.rp:])))
fi.rp += 4
if fieldLen >= 0 {
if len(fi.src[fi.rp:]) < fieldLen {
err = errors.Errorf("Record incomplete rp=%d src=%v", fi.rp, fi.src)
return
}
buf = fi.src[fi.rp : fi.rp+fieldLen]
fi.rp += fieldLen
}
return
}
// RecordStart adds record header to the buf
func RecordStart(buf []byte, fieldCount int) []byte {
return pgio.AppendUint32(buf, uint32(fieldCount))
}
// RecordAdd adds record field to the buf
func RecordAdd(buf []byte, oid uint32, fieldBytes []byte) []byte {
buf = pgio.AppendUint32(buf, oid)
buf = pgio.AppendUint32(buf, uint32(len(fieldBytes)))
buf = append(buf, fieldBytes...)
return buf
}
// RecordAddNull adds null value as a field to the buf
func RecordAddNull(buf []byte, oid uint32) []byte {
return pgio.AppendInt32(buf, int32(-1))
}
+153
View File
@@ -0,0 +1,153 @@
package pgtype
import (
"github.com/jackc/pgtype/binary"
errors "golang.org/x/xerrors"
)
type Composite struct {
fields []Value
Status Status
}
// NewComposite creates a Composite object, which acts as a "schema" for
// SQL composite values.
// To pass Composite as SQL parameter first set it's fields, either by
// passing initialized Value{} instances to NewComposite or by calling
// SetFields method
// To read composite fields back pass result of Scan() method
// to query Scan function.
func NewComposite(fields ...Value) *Composite {
return &Composite{fields, Present}
}
func (src Composite) Get() interface{} {
switch src.Status {
case Present:
return src
case Null:
return nil
default:
return src.Status
}
}
// Set is called internally when passing query arguments.
func (dst *Composite) Set(src interface{}) error {
if src == nil {
*dst = Composite{Status: Null}
return nil
}
switch value := src.(type) {
case []Value:
if len(value) != len(dst.fields) {
return errors.Errorf("Number of fields don't match. Composite has %d fields", len(dst.fields))
}
for i, v := range value {
if err := dst.fields[i].Set(v); err != nil {
return err
}
}
dst.Status = Present
default:
return errors.Errorf("Can not convert %v to Composite", src)
}
return nil
}
// AssignTo should never be called on composite value directly
func (src Composite) AssignTo(dst interface{}) error {
return errors.New("Pass Composite.Scan() to deconstruct composite")
}
func (src Composite) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
return EncodeRow(ci, buf, src.fields...)
}
// DecodeBinary implements BinaryDecoder interface.
// Opposite to Record, fields in a composite act as a "schema"
// and decoding fails if SQL value can't be assigned due to
// type mismatch
func (dst *Composite) DecodeBinary(ci *ConnInfo, buf []byte) (err error) {
if buf == nil {
dst.Status = Null
return nil
}
fieldIter, fieldCount, err := binary.NewRecordFieldIterator(buf)
if err != nil {
return err
} else if len(dst.fields) != fieldCount {
return errors.Errorf("SQL composite can't be read, field count mismatch. expected %d , found %d", len(dst.fields), fieldCount)
}
_, fieldBytes, eof, err := fieldIter.Next()
for i := 0; !eof; i++ {
if err != nil {
return err
}
binaryDecoder, ok := dst.fields[i].(BinaryDecoder)
if !ok {
return errors.New("Composite field doesn't support binary protocol")
}
if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil {
return err
}
_, fieldBytes, eof, err = fieldIter.Next()
}
dst.Status = Present
return nil
}
// Scan is a helper function to perform "nested" scan of
// a composite value when scanning a query result row.
// isNull is set if scanned value is NULL
// Rest of arguments are set in the order of fields in the composite
//
// Use of Scan method doesn't modify original composite
func (src Composite) Scan(isNull *bool, dst ...interface{}) BinaryDecoderFunc {
return func(ci *ConnInfo, buf []byte) error {
if err := src.DecodeBinary(ci, buf); err != nil {
return err
}
if src.Status == Null {
*isNull = true
return nil
}
for i, f := range src.fields {
if err := f.AssignTo(dst[i]); err != nil {
return err
}
}
return nil
}
}
// SetFields sets Composite's fields to corresponding values
func (dst *Composite) SetFields(values ...interface{}) error {
if len(values) != len(dst.fields) {
return errors.Errorf("Number of fields don't match. Composite has %d fields", len(dst.fields))
}
for i, v := range values {
if err := dst.fields[i].Set(v); err != nil {
return err
}
}
dst.Status = Present
return nil
}
+196
View File
@@ -0,0 +1,196 @@
package pgtype_test
import (
"testing"
"github.com/jackc/pgtype"
"github.com/jackc/pgtype/binary"
errors "golang.org/x/xerrors"
)
type MyCompositeRaw struct {
A int32
B *string
}
func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) {
a := pgtype.Int4{src.A, pgtype.Present}
fieldBytes := make([]byte, 0, 64)
fieldBytes, _ = a.EncodeBinary(ci, fieldBytes[:0])
newBuf = binary.RecordStart(buf, 2)
newBuf = binary.RecordAdd(newBuf, pgtype.Int4OID, fieldBytes)
if src.B != nil {
fieldBytes, _ = pgtype.Text{*src.B, pgtype.Present}.EncodeBinary(ci, fieldBytes[:0])
newBuf = binary.RecordAdd(newBuf, pgtype.TextOID, fieldBytes)
} else {
newBuf = binary.RecordAddNull(newBuf, pgtype.TextOID)
}
return
}
func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
a := pgtype.Int4{}
b := pgtype.Text{}
fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src)
if err != nil {
return err
}
if 2 != fieldCount {
return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=2", fieldCount)
}
_, fieldBytes, eof, err := fieldIter.Next()
if eof || err != nil {
return errors.New("Bad record")
}
if err = a.DecodeBinary(ci, fieldBytes); err != nil {
return err
}
_, fieldBytes, eof, err = fieldIter.Next()
if eof || err != nil {
return errors.New("Bad record")
}
if err = b.DecodeBinary(ci, fieldBytes); err != nil {
return err
}
dst.A = a.Int
if b.Status == pgtype.Present {
dst.B = &b.String
} else {
dst.B = nil
}
return nil
}
var x []byte
func BenchmarkBinaryEncodingManual(b *testing.B) {
buf := make([]byte, 0, 128)
ci := pgtype.NewConnInfo()
v := MyCompositeRaw{4, ptrS("ABCDEFG")}
b.ResetTimer()
for n := 0; n < b.N; n++ {
buf, _ = v.EncodeBinary(ci, buf[:0])
}
x = buf
}
func BenchmarkBinaryEncodingHelper(b *testing.B) {
buf := make([]byte, 0, 128)
ci := pgtype.NewConnInfo()
v := MyType{4, ptrS("ABCDEFG")}
b.ResetTimer()
for n := 0; n < b.N; n++ {
buf, _ = v.EncodeBinary(ci, buf[:0])
}
x = buf
}
func BenchmarkBinaryEncodingComposite(b *testing.B) {
buf := make([]byte, 0, 128)
ci := pgtype.NewConnInfo()
f1 := 2
f2 := ptrS("bar")
c := pgtype.NewComposite(&pgtype.Int4{}, &pgtype.Text{})
b.ResetTimer()
for n := 0; n < b.N; n++ {
c.SetFields(f1, f2)
buf, _ = c.EncodeBinary(ci, buf[:0])
}
x = buf
}
func BenchmarkBinaryEncodingJSON(b *testing.B) {
buf := make([]byte, 0, 128)
ci := pgtype.NewConnInfo()
v := MyCompositeRaw{4, ptrS("ABCDEFG")}
j := pgtype.JSON{}
b.ResetTimer()
for n := 0; n < b.N; n++ {
j.Set(v)
buf, _ = j.EncodeBinary(ci, buf[:0])
}
x = buf
}
var dstRaw MyCompositeRaw
func BenchmarkBinaryDecodingManual(b *testing.B) {
ci := pgtype.NewConnInfo()
buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil)
dst := MyCompositeRaw{}
b.ResetTimer()
for n := 0; n < b.N; n++ {
err := dst.DecodeBinary(ci, buf)
E(err)
}
dstRaw = dst
}
var dstMyType MyType
func BenchmarkBinaryDecodingHelpers(b *testing.B) {
ci := pgtype.NewConnInfo()
buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil)
dst := MyType{}
b.ResetTimer()
for n := 0; n < b.N; n++ {
err := dst.DecodeBinary(ci, buf)
E(err)
}
dstMyType = dst
}
var gf1 int
var gf2 *string
func BenchmarkBinaryDecodingCompositeScan(b *testing.B) {
ci := pgtype.NewConnInfo()
buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil)
var isNull bool
var f1 int
var f2 *string
c := pgtype.NewComposite(&pgtype.Int4{}, &pgtype.Text{})
b.ResetTimer()
for n := 0; n < b.N; n++ {
err := c.Scan(&isNull, &f1, &f2).DecodeBinary(ci, buf)
E(err)
}
gf1 = f1
gf2 = f2
}
func BenchmarkBinaryDecodingJSON(b *testing.B) {
ci := pgtype.NewConnInfo()
j := pgtype.JSON{}
j.Set(MyCompositeRaw{4, ptrS("ABCDEFG")})
buf, _ := j.EncodeBinary(ci, nil)
j = pgtype.JSON{}
dst := MyCompositeRaw{}
b.ResetTimer()
for n := 0; n < b.N; n++ {
err := j.DecodeBinary(ci, buf)
E(err)
err = j.AssignTo(&dst)
E(err)
}
dstRaw = dst
}
+57
View File
@@ -0,0 +1,57 @@
package pgtype_test
import (
"context"
"fmt"
"os"
"github.com/jackc/pgtype"
pgx "github.com/jackc/pgx/v4"
)
//ExampleComposite demonstrates use of Row() function to pass and receive
// back composite types without creating boilderplate custom types.
func Example_composite() {
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
E(err)
defer conn.Close(context.Background())
_, err = conn.Exec(context.Background(), `drop type if exists mytype;
create type mytype as (
a int4,
b text
);`)
E(err)
defer conn.Exec(context.Background(), "drop type mytype")
qrf := pgx.QueryResultFormats{pgx.BinaryFormatCode}
var isNull bool
var a int
var b *string
c := pgtype.NewComposite(&pgtype.Int4{}, &pgtype.Text{})
c.SetFields(2, "bar")
err = conn.QueryRow(context.Background(), "select $1::mytype", qrf, c).
Scan(c.Scan(&isNull, &a, &b))
E(err)
fmt.Printf("First: isNull=%v a=%d b=%s\n", isNull, a, *b)
err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype", qrf).Scan(c.Scan(&isNull, &a, &b))
E(err)
fmt.Printf("Second: isNull=%v a=%d b=%v\n", isNull, a, b)
err = conn.QueryRow(context.Background(), "select NULL::mytype", qrf).Scan(c.Scan(&isNull, &a, &b))
E(err)
fmt.Printf("Third: isNull=%v\n", isNull)
// Output:
// First: isNull=false a=2 b=bar
// Second: isNull=false a=1 b=<nil>
// Third: isNull=true
}
+63
View File
@@ -5,6 +5,7 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/jackc/pgtype/binary"
errors "golang.org/x/xerrors" errors "golang.org/x/xerrors"
) )
@@ -433,6 +434,68 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) {
return nil, false return nil, false
} }
// ScanRowValue decodes ROW()'s and composite type
// from src argument using provided decoders. Decoders should match
// order and count of fields of record being decoded.
//
// In practice you can pass pgtype.Value types as decoders, as
// most of them implement BinaryDecoder interface.
//
// ScanRowValue takes ownership of src, caller MUST not use it after call
func ScanRowValue(ci *ConnInfo, src []byte, dst ...BinaryDecoder) error {
fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src)
if err != nil {
return err
}
if len(dst) != fieldCount {
return errors.Errorf("can't scan row value, number of fields don't match: found=%d expected=%d", fieldCount, len(dst))
}
_, fieldBytes, eof, err := fieldIter.Next()
for i := 0; !eof; i++ {
if err != nil {
return err
}
if err = dst[i].DecodeBinary(ci, fieldBytes); err != nil {
return err
}
_, fieldBytes, eof, err = fieldIter.Next()
}
return nil
}
// EncodeRow builds a binary representation of row values (row(), composite types)
func EncodeRow(ci *ConnInfo, buf []byte, fields ...Value) (newBuf []byte, err error) {
fieldBytes := make([]byte, 0, 128)
newBuf = binary.RecordStart(buf, len(fields))
for _, f := range fields {
dt, ok := ci.DataTypeForValue(f)
if !ok {
return nil, errors.Errorf("Unknown OID for %s", f)
}
if f.Get() != nil {
binaryEncoder, ok := f.(BinaryEncoder)
if !ok {
return nil, errors.Errorf("record field doesn't implement binary encoding: %s", reflect.TypeOf(f).Name())
}
fieldBytes, err = binaryEncoder.EncodeBinary(ci, fieldBytes[:0])
if err != nil {
return nil, err
}
newBuf = binary.RecordAdd(newBuf, dt.OID, fieldBytes)
} else {
newBuf = binary.RecordAddNull(newBuf, dt.OID)
}
}
return
}
func init() { func init() {
kindTypes = map[reflect.Kind]reflect.Type{ kindTypes = map[reflect.Kind]reflect.Type{
reflect.Bool: reflect.TypeOf(false), reflect.Bool: reflect.TypeOf(false),
+101
View File
@@ -0,0 +1,101 @@
package pgtype_test
import (
"context"
"fmt"
"os"
"github.com/jackc/pgtype"
pgx "github.com/jackc/pgx/v4"
errors "golang.org/x/xerrors"
)
type MyType struct {
a int32 // NULL will cause decoding error
b *string // there can be NULL in this position in SQL
}
func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
if src == nil {
return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs")
}
a := pgtype.Int4{}
b := pgtype.Text{}
if err := pgtype.ScanRowValue(ci, src, &a, &b); err != nil {
return err
}
// type compatibility is checked by AssignTo
// only lossless assignments will succeed
if err := a.AssignTo(&dst.a); err != nil {
return err
}
// AssignTo also deals with null value handling
if err := b.AssignTo(&dst.b); err != nil {
return err
}
return nil
}
func (src MyType) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) {
a := pgtype.Int4{src.a, pgtype.Present}
var b pgtype.Text
if src.b != nil {
b = pgtype.Text{*src.b, pgtype.Present}
} else {
b = pgtype.Text{Status: pgtype.Null}
}
return pgtype.EncodeRow(ci, buf, &a, &b)
}
func ptrS(s string) *string {
return &s
}
func E(err error) {
if err != nil {
panic(err)
}
}
// ExampleCustomCompositeTypes demonstrates how support for custom types mappable to SQL
// composites can be added.
func Example_customCompositeTypes() {
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
E(err)
defer conn.Close(context.Background())
_, err = conn.Exec(context.Background(), `drop type if exists mytype;
create type mytype as (
a int4,
b text
);`)
E(err)
defer conn.Exec(context.Background(), "drop type mytype")
var result *MyType
// Demonstrates both passing and reading back composite values
err = conn.QueryRow(context.Background(), "select $1::mytype",
pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}).
Scan(&result)
E(err)
fmt.Printf("First row: a=%d b=%s\n", result.a, *result.b)
// Because we scan into &*MyType, NULLs are handled generically by assigning nil to result
err = conn.QueryRow(context.Background(), "select NULL::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result)
E(err)
fmt.Printf("Second row: %v\n", result)
// Output:
// First row: a=1 b=foo
// Second row: <nil>
}
+18
View File
@@ -174,6 +174,24 @@ type TextEncoder interface {
EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error)
} }
//The BinaryDecoderFunc type is an adapter to allow the use of ordinary functions as BinaryDecoder types.
// If f is a function with the appropriate signature, BinaryDecoderFunc(f) is a BinaryDecoder that calls f.
type BinaryDecoderFunc func(ci *ConnInfo, src []byte) error
// DecodeBinary calls f(ci, src)
func (f BinaryDecoderFunc) DecodeBinary(ci *ConnInfo, src []byte) error {
return f(ci, src)
}
//The BinaryEncoderFunc type is an adapter to allow the use of ordinary functions as BinaryDecoder types.
// If f is a function with the appropriate signature, BinaryEncoderFunc(f) is a BinaryDecoder that calls f.
type BinaryEncoderFunc func(ci *ConnInfo, buf []byte) ([]byte, error)
// EncodeBinary calls f(ci, buf)
func (f BinaryEncoderFunc) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) {
return f(ci, buf)
}
var errUndefined = errors.New("cannot encode status undefined") var errUndefined = errors.New("cannot encode status undefined")
var errBadStatus = errors.New("invalid status") var errBadStatus = errors.New("invalid status")
+37 -39
View File
@@ -1,9 +1,10 @@
package pgtype package pgtype
import ( import (
"encoding/binary"
"reflect" "reflect"
"github.com/jackc/pgtype/binary"
errors "golang.org/x/xerrors" errors "golang.org/x/xerrors"
) )
@@ -78,57 +79,54 @@ func (src *Record) AssignTo(dst interface{}) error {
return errors.Errorf("cannot decode %#v into %T", src, dst) return errors.Errorf("cannot decode %#v into %T", src, dst)
} }
func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDecoder, error) {
var binaryDecoder BinaryDecoder
if dt, ok := ci.DataTypeForOID(fieldOID); ok {
binaryDecoder, _ = dt.Value.(BinaryDecoder)
} else {
return nil, errors.Errorf("unknown oid while decoding record: %v", fieldOID)
}
if binaryDecoder == nil {
return nil, errors.Errorf("no binary decoder registered for: %v", fieldOID)
}
// Duplicate struct to scan into
binaryDecoder = reflect.New(reflect.ValueOf(binaryDecoder).Elem().Type()).Interface().(BinaryDecoder)
*v = binaryDecoder.(Value)
return binaryDecoder, nil
}
func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil { if src == nil {
*dst = Record{Status: Null} *dst = Record{Status: Null}
return nil return nil
} }
rp := 0 fieldIter, fieldCount, err := binary.NewRecordFieldIterator(src)
if err != nil {
if len(src[rp:]) < 4 { return err
return errors.Errorf("Record incomplete %v", src)
} }
fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
fields := make([]Value, fieldCount) fields := make([]Value, fieldCount)
fieldOID, fieldBytes, eof, err := fieldIter.Next()
for i := 0; i < fieldCount; i++ { for i := 0; !eof; i++ {
if len(src[rp:]) < 8 { if err != nil {
return errors.Errorf("Record incomplete %v", src)
}
fieldOID := binary.BigEndian.Uint32(src[rp:])
rp += 4
fieldLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
var binaryDecoder BinaryDecoder
if dt, ok := ci.DataTypeForOID(fieldOID); ok {
binaryDecoder, _ = dt.Value.(BinaryDecoder)
}
if binaryDecoder == nil {
return errors.Errorf("unknown oid while decoding record: %v", fieldOID)
}
var fieldBytes []byte
if fieldLen >= 0 {
if len(src[rp:]) < fieldLen {
return errors.Errorf("Record incomplete %v", src)
}
fieldBytes = src[rp : rp+fieldLen]
rp += fieldLen
}
// Duplicate struct to scan into
binaryDecoder = reflect.New(reflect.ValueOf(binaryDecoder).Elem().Type()).Interface().(BinaryDecoder)
if err := binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil {
return err return err
} }
fields[i] = binaryDecoder.(Value) binaryDecoder, err := prepareNewBinaryDecoder(ci, fieldOID, &fields[i])
if err != nil {
return err
}
if err = binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil {
return err
}
fieldOID, fieldBytes, eof, err = fieldIter.Next()
} }
*dst = Record{Fields: fields, Status: Present} *dst = Record{Fields: fields, Status: Present}
+128 -77
View File
@@ -11,94 +11,145 @@ import (
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
) )
var recordTests = []struct {
sql string
expected pgtype.Record
}{
{
sql: `select row()`,
expected: pgtype.Record{
Fields: []pgtype.Value{},
Status: pgtype.Present,
},
},
{
sql: `select row('foo'::text, 42::int4)`,
expected: pgtype.Record{
Fields: []pgtype.Value{
&pgtype.Text{String: "foo", Status: pgtype.Present},
&pgtype.Int4{Int: 42, Status: pgtype.Present},
},
Status: pgtype.Present,
},
},
{
sql: `select row(100.0::float4, 1.09::float4)`,
expected: pgtype.Record{
Fields: []pgtype.Value{
&pgtype.Float4{Float: 100, Status: pgtype.Present},
&pgtype.Float4{Float: 1.09, Status: pgtype.Present},
},
Status: pgtype.Present,
},
},
{
sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`,
expected: pgtype.Record{
Fields: []pgtype.Value{
&pgtype.Text{String: "foo", Status: pgtype.Present},
&pgtype.Int4Array{
Elements: []pgtype.Int4{
{Int: 1, Status: pgtype.Present},
{Int: 2, Status: pgtype.Present},
{Status: pgtype.Null},
{Int: 4, Status: pgtype.Present},
},
Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}},
Status: pgtype.Present,
},
&pgtype.Int4{Int: 42, Status: pgtype.Present},
},
Status: pgtype.Present,
},
},
{
sql: `select row(null)`,
expected: pgtype.Record{
Fields: []pgtype.Value{
&pgtype.Unknown{Status: pgtype.Null},
},
Status: pgtype.Present,
},
},
{
sql: `select null::record`,
expected: pgtype.Record{
Status: pgtype.Null,
},
},
}
// row values are binary compatible with records, so we test our helper
// routines here
func TestScanRowValue(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
for i := 0; i < len(recordTests); i++ {
tt := recordTests[i]
psName := fmt.Sprintf("test%d", i)
_, err := conn.Prepare(context.Background(), psName, tt.sql)
if err != nil {
t.Fatal(err)
}
t.Run(tt.sql, func(t *testing.T) {
desc := []pgtype.BinaryDecoder{}
for _, f := range tt.expected.Fields {
desc = append(desc, f.(pgtype.BinaryDecoder))
}
var raw pgtype.GenericBinary
if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&raw); err != nil {
t.Error(err)
return
}
if raw.Status == pgtype.Null {
// ScanRowValue deals with complete rows only, NULL values (but NOT null fields)
// should be handled by the calling code
return
}
if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err != nil {
t.Error(err)
}
// borrow fields from a neighbor test, this makes scan always fail
desc = desc[:0]
for _, f := range recordTests[(i+1)%len(recordTests)].expected.Fields {
desc = append(desc, f.(pgtype.BinaryDecoder))
}
if err := pgtype.ScanRowValue(conn.ConnInfo(), raw.Bytes, desc...); err == nil {
t.Error("Matching scan didn't fail, despite fields not mathching query result")
}
})
}
}
func TestRecordTranscode(t *testing.T) { func TestRecordTranscode(t *testing.T) {
conn := testutil.MustConnectPgx(t) conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn) defer testutil.MustCloseContext(t, conn)
tests := []struct { for i, tt := range recordTests {
sql string
expected pgtype.Record
}{
{
sql: `select row()`,
expected: pgtype.Record{
Fields: []pgtype.Value{},
Status: pgtype.Present,
},
},
{
sql: `select row('foo'::text, 42::int4)`,
expected: pgtype.Record{
Fields: []pgtype.Value{
&pgtype.Text{String: "foo", Status: pgtype.Present},
&pgtype.Int4{Int: 42, Status: pgtype.Present},
},
Status: pgtype.Present,
},
},
{
sql: `select row(100.0::float4, 1.09::float4)`,
expected: pgtype.Record{
Fields: []pgtype.Value{
&pgtype.Float4{Float: 100, Status: pgtype.Present},
&pgtype.Float4{Float: 1.09, Status: pgtype.Present},
},
Status: pgtype.Present,
},
},
{
sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`,
expected: pgtype.Record{
Fields: []pgtype.Value{
&pgtype.Text{String: "foo", Status: pgtype.Present},
&pgtype.Int4Array{
Elements: []pgtype.Int4{
{Int: 1, Status: pgtype.Present},
{Int: 2, Status: pgtype.Present},
{Status: pgtype.Null},
{Int: 4, Status: pgtype.Present},
},
Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}},
Status: pgtype.Present,
},
&pgtype.Int4{Int: 42, Status: pgtype.Present},
},
Status: pgtype.Present,
},
},
{
sql: `select row(null)`,
expected: pgtype.Record{
Fields: []pgtype.Value{
&pgtype.Unknown{Status: pgtype.Null},
},
Status: pgtype.Present,
},
},
{
sql: `select null::record`,
expected: pgtype.Record{
Status: pgtype.Null,
},
},
}
for i, tt := range tests {
psName := fmt.Sprintf("test%d", i) psName := fmt.Sprintf("test%d", i)
_, err := conn.Prepare(context.Background(), psName, tt.sql) _, err := conn.Prepare(context.Background(), psName, tt.sql)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var result pgtype.Record t.Run(tt.sql, func(t *testing.T) {
if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { var result pgtype.Record
t.Errorf("%d: %v", i, err) if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil {
continue t.Errorf("%v", err)
} return
}
if !reflect.DeepEqual(tt.expected, result) {
t.Errorf("expected %#v, got %#v", tt.expected, result)
}
})
if !reflect.DeepEqual(tt.expected, result) {
t.Errorf("%d: expected %#v, got %#v", i, tt.expected, result)
}
} }
} }