diff --git a/composite_bench_test.go b/composite_bench_test.go index 1a5a7492..d4eb0ac7 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -104,7 +104,7 @@ func BenchmarkBinaryEncodingComposite(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - c.SetFields(f1, f2) + c.Set([]interface{}{f1, f2}) buf, _ = c.EncodeBinary(ci, buf[:0]) } x = buf diff --git a/composite_type.go b/composite_type.go index 7cc620d5..76b32b86 100644 --- a/composite_type.go +++ b/composite_type.go @@ -20,13 +20,17 @@ type CompositeType struct { // To read composite fields back pass result of Scan() method // to query Scan function. func NewCompositeType(fields ...Value) *CompositeType { - return &CompositeType{fields, Present} + return &CompositeType{fields, Undefined} } func (src CompositeType) Get() interface{} { switch src.status { case Present: - return src + results := make([]interface{}, len(src.fields)) + for i := range results { + results[i] = src.fields[i].Get() + } + return results case Null: return nil default: @@ -34,17 +38,16 @@ func (src CompositeType) Get() interface{} { } } -// Set is called internally when passing query arguments. func (dst *CompositeType) Set(src interface{}) error { if src == nil { - *dst = CompositeType{status: Null} + dst.status = Null return nil } switch value := src.(type) { - case []Value: + case []interface{}: if len(value) != len(dst.fields) { - return errors.Errorf("Number of fields don't match. Composite has %d fields", len(dst.fields)) + return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.fields)) } for i, v := range value { if err := dst.fields[i].Set(v); err != nil { @@ -52,6 +55,12 @@ func (dst *CompositeType) Set(src interface{}) error { } } dst.status = Present + case *[]interface{}: + if value == nil { + dst.status = Null + return nil + } + return dst.Set(*value) default: return errors.Errorf("Can not convert %v to Composite", src) } @@ -138,20 +147,6 @@ func (src CompositeType) Scan(isNull *bool, dst ...interface{}) BinaryDecoderFun } } -// SetFields sets Composite's fields to corresponding values -func (dst *CompositeType) SetFields(values ...interface{}) error { - if len(values) != len(dst.fields) { - return errors.Errorf("Number of fields don't match. Composite has %d fields", len(dst.fields)) - } - for i, v := range values { - if err := dst.fields[i].Set(v); err != nil { - return err - } - } - dst.status = Present - return nil -} - type CompositeBinaryScanner struct { rp int src []byte diff --git a/composite_type_test.go b/composite_type_test.go index 4f614fc5..3e38b6dc 100644 --- a/composite_type_test.go +++ b/composite_type_test.go @@ -4,11 +4,55 @@ import ( "context" "fmt" "os" + "testing" "github.com/jackc/pgtype" pgx "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" ) +func TestCompositeTypeSetAndGet(t *testing.T) { + ct := pgtype.NewCompositeType(&pgtype.Text{}, &pgtype.Int4{}) + assert.Equal(t, pgtype.Undefined, ct.Get()) + + nilTests := []struct { + src interface{} + }{ + {nil}, // nil interface + {(*[]interface{})(nil)}, // typed nil + } + + for i, tt := range nilTests { + err := ct.Set(tt.src) + assert.NoErrorf(t, err, "%d", i) + assert.Equal(t, nil, ct.Get()) + } + + compatibleValuesTests := []struct { + src []interface{} + expected []interface{} + }{ + { + src: []interface{}{"foo", int32(42)}, + expected: []interface{}{"foo", int32(42)}, + }, + { + src: []interface{}{nil, nil}, + expected: []interface{}{nil, nil}, + }, + { + src: []interface{}{&pgtype.Text{String: "hi", Status: pgtype.Present}, &pgtype.Int4{Int: 7, Status: pgtype.Present}}, + expected: []interface{}{"hi", int32(7)}, + }, + } + + for i, tt := range compatibleValuesTests { + err := ct.Set(tt.src) + assert.NoErrorf(t, err, "%d", i) + assert.EqualValues(t, tt.expected, ct.Get()) + } +} + //ExampleComposite demonstrates use of Row() function to pass and receive // back composite types without creating boilderplate custom types. func Example_composite() { @@ -32,7 +76,7 @@ create type mytype as ( var b *string c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{}) - c.SetFields(2, "bar") + c.Set([]interface{}{2, "bar"}) err = conn.QueryRow(context.Background(), "select $1::mytype", qrf, c). Scan(c.Scan(&isNull, &a, &b))