diff --git a/composite_bench_test.go b/composite_bench_test.go index d4eb0ac7..9eaf7632 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -160,7 +160,6 @@ var gf2 *string func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { ci := pgtype.NewConnInfo() buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) - var isNull bool var f1 int var f2 *string @@ -168,8 +167,14 @@ func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - err := c.Scan(&isNull, &f1, &f2).DecodeBinary(ci, buf) - E(err) + err := c.DecodeBinary(ci, buf) + if err != nil { + b.Fatal(err) + } + err = c.AssignTo([]interface{}{&f1, &f2}) + if err != nil { + b.Fatal(err) + } } gf1 = f1 gf2 = f2 diff --git a/composite_type.go b/composite_type.go index 76b32b86..53386f37 100644 --- a/composite_type.go +++ b/composite_type.go @@ -70,7 +70,45 @@ func (dst *CompositeType) Set(src interface{}) error { // AssignTo should never be called on composite value directly func (src CompositeType) AssignTo(dst interface{}) error { - return errors.New("Pass Composite.Scan() to deconstruct composite") + switch src.status { + case Present: + switch v := dst.(type) { + case []interface{}: + if len(v) != len(src.fields) { + return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.fields)) + } + for i := range src.fields { + if v[i] == nil { + continue + } + + assignToErr := src.fields[i].AssignTo(v[i]) + if assignToErr != nil { + // Try to use get / set instead -- this avoids every type having to be able to AssignTo type of self. + setSucceeded := false + if setter, ok := v[i].(Value); ok { + err := setter.Set(src.fields[i].Get()) + setSucceeded = err == nil + } + if !setSucceeded { + return errors.Errorf("unable to assign to dst[%d]: %v", i, assignToErr) + } + } + + } + return nil + case *[]interface{}: + return src.AssignTo(*v) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return errors.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + return errors.Errorf("cannot decode %#v into %T", src, dst) } func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { @@ -121,32 +159,6 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) (err error) { return nil } -// Scan is a helper function to perform "nested" scan of -// a composite value when scanning a query result row. -// isNull is set if scanned value is NULL -// Rest of arguments are set in the order of fields in the composite -// -// Use of Scan method doesn't modify original composite -func (src CompositeType) Scan(isNull *bool, dst ...interface{}) BinaryDecoderFunc { - return func(ci *ConnInfo, buf []byte) error { - if err := src.DecodeBinary(ci, buf); err != nil { - return err - } - - if src.status == Null { - *isNull = true - return nil - } - - for i, f := range src.fields { - if err := f.AssignTo(dst[i]); err != nil { - return err - } - } - return nil - } -} - type CompositeBinaryScanner struct { rp int src []byte diff --git a/composite_type_test.go b/composite_type_test.go index 3e38b6dc..56b9318b 100644 --- a/composite_type_test.go +++ b/composite_type_test.go @@ -53,49 +53,144 @@ func TestCompositeTypeSetAndGet(t *testing.T) { } } -//ExampleComposite demonstrates use of Row() function to pass and receive -// back composite types without creating boilderplate custom types. +func TestCompositeTypeAssignTo(t *testing.T) { + ct := pgtype.NewCompositeType(&pgtype.Text{}, &pgtype.Int4{}) + + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var a string + var b int32 + + err = ct.AssignTo([]interface{}{&a, &b}) + assert.NoError(t, err) + + assert.Equal(t, "foo", a) + assert.Equal(t, int32(42), b) + } + + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var a pgtype.Text + var b pgtype.Int4 + + err = ct.AssignTo([]interface{}{&a, &b}) + assert.NoError(t, err) + + assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a) + assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b) + } + + // Allow nil destination component as no-op + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var b int32 + + err = ct.AssignTo([]interface{}{nil, &b}) + assert.NoError(t, err) + + assert.Equal(t, int32(42), b) + } + + // *[]interface{} dest when null + { + err := ct.Set(nil) + assert.NoError(t, err) + + var a pgtype.Text + var b pgtype.Int4 + dst := []interface{}{&a, &b} + + err = ct.AssignTo(&dst) + assert.NoError(t, err) + + assert.Nil(t, dst) + } + + // *[]interface{} dest when not null + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var a pgtype.Text + var b pgtype.Int4 + dst := []interface{}{&a, &b} + + err = ct.AssignTo(&dst) + assert.NoError(t, err) + + assert.NotNil(t, dst) + assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a) + assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b) + } +} + func Example_composite() { conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - E(err) + if err != nil { + fmt.Println(err) + return + } defer conn.Close(context.Background()) - _, err = conn.Exec(context.Background(), `drop type if exists mytype; + _, err = conn.Exec(context.Background(), `drop type if exists mytype;`) + if err != nil { + fmt.Println(err) + return + } -create type mytype as ( + _, err = conn.Exec(context.Background(), `create type mytype as ( a int4, b text );`) - E(err) + if err != nil { + fmt.Println(err) + return + } defer conn.Exec(context.Background(), "drop type mytype") - qrf := pgx.QueryResultFormats{pgx.BinaryFormatCode} + var oid uint32 + err = conn.QueryRow(context.Background(), `select 'mytype'::regtype::oid`).Scan(&oid) + if err != nil { + fmt.Println(err) + return + } + + c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{}) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: c, Name: "mytype", OID: oid}) - var isNull bool var a int var b *string - c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{}) - c.Set([]interface{}{2, "bar"}) + err = conn.QueryRow(context.Background(), "select $1::mytype", []interface{}{2, "bar"}).Scan([]interface{}{&a, &b}) + if err != nil { + fmt.Println(err) + return + } - err = conn.QueryRow(context.Background(), "select $1::mytype", qrf, c). - Scan(c.Scan(&isNull, &a, &b)) + fmt.Printf("First: a=%d b=%s\n", a, *b) + + err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan([]interface{}{&a, &b}) + if err != nil { + fmt.Println(err) + return + } + + fmt.Printf("Second: a=%d b=%v\n", a, b) + + scanTarget := []interface{}{&a, &b} + err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(&scanTarget) E(err) - fmt.Printf("First: isNull=%v a=%d b=%s\n", isNull, a, *b) - - err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype", qrf).Scan(c.Scan(&isNull, &a, &b)) - E(err) - - fmt.Printf("Second: isNull=%v a=%d b=%v\n", isNull, a, b) - - err = conn.QueryRow(context.Background(), "select NULL::mytype", qrf).Scan(c.Scan(&isNull, &a, &b)) - E(err) - - fmt.Printf("Third: isNull=%v\n", isNull) + fmt.Printf("Third: isNull=%v\n", scanTarget == nil) // Output: - // First: isNull=false a=2 b=bar - // Second: isNull=false a=1 b= + // First: a=2 b=bar + // Second: a=1 b= // Third: isNull=true } diff --git a/pgtype.go b/pgtype.go index 193980ef..25f1a1d5 100644 --- a/pgtype.go +++ b/pgtype.go @@ -197,24 +197,6 @@ type TextEncoder interface { EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) } -//The BinaryDecoderFunc type is an adapter to allow the use of ordinary functions as BinaryDecoder types. -// If f is a function with the appropriate signature, BinaryDecoderFunc(f) is a BinaryDecoder that calls f. -type BinaryDecoderFunc func(ci *ConnInfo, src []byte) error - -// DecodeBinary calls f(ci, src) -func (f BinaryDecoderFunc) DecodeBinary(ci *ConnInfo, src []byte) error { - return f(ci, src) -} - -//The BinaryEncoderFunc type is an adapter to allow the use of ordinary functions as BinaryDecoder types. -// If f is a function with the appropriate signature, BinaryEncoderFunc(f) is a BinaryDecoder that calls f. -type BinaryEncoderFunc func(ci *ConnInfo, buf []byte) ([]byte, error) - -// EncodeBinary calls f(ci, buf) -func (f BinaryEncoderFunc) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { - return f(ci, buf) -} - var errUndefined = errors.New("cannot encode status undefined") var errBadStatus = errors.New("invalid status")