diff --git a/composite_bench_test.go b/composite_bench_test.go index 9eaf7632..e1dd6d04 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -100,7 +100,7 @@ func BenchmarkBinaryEncodingComposite(b *testing.B) { ci := pgtype.NewConnInfo() f1 := 2 f2 := ptrS("bar") - c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{}) + c := pgtype.NewCompositeType("test", &pgtype.Int4{}, &pgtype.Text{}) b.ResetTimer() for n := 0; n < b.N; n++ { @@ -163,7 +163,7 @@ func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { var f1 int var f2 *string - c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{}) + c := pgtype.NewCompositeType("test", &pgtype.Int4{}, &pgtype.Text{}) b.ResetTimer() for n := 0; n < b.N; n++ { diff --git a/composite_type.go b/composite_type.go index 53386f37..03d88aea 100644 --- a/composite_type.go +++ b/composite_type.go @@ -8,8 +8,10 @@ import ( ) type CompositeType struct { - fields []Value status Status + + typeName string + fields []Value } // NewCompositeType creates a Composite object, which acts as a "schema" for @@ -19,8 +21,8 @@ type CompositeType struct { // SetFields method // To read composite fields back pass result of Scan() method // to query Scan function. -func NewCompositeType(fields ...Value) *CompositeType { - return &CompositeType{fields, Undefined} +func NewCompositeType(typeName string, fields ...Value) *CompositeType { + return &CompositeType{typeName: typeName, fields: fields} } func (src CompositeType) Get() interface{} { @@ -38,6 +40,23 @@ func (src CompositeType) Get() interface{} { } } +func (ct *CompositeType) NewTypeValue() Value { + a := &CompositeType{ + typeName: ct.typeName, + fields: make([]Value, len(ct.fields)), + } + + for i := range ct.fields { + a.fields[i] = NewValue(ct.fields[i]) + } + + return a +} + +func (ct *CompositeType) TypeName() string { + return ct.typeName +} + func (dst *CompositeType) Set(src interface{}) error { if src == nil { dst.status = Null diff --git a/composite_type_test.go b/composite_type_test.go index 56b9318b..92ecc849 100644 --- a/composite_type_test.go +++ b/composite_type_test.go @@ -12,7 +12,7 @@ import ( ) func TestCompositeTypeSetAndGet(t *testing.T) { - ct := pgtype.NewCompositeType(&pgtype.Text{}, &pgtype.Int4{}) + ct := pgtype.NewCompositeType("test", &pgtype.Text{}, &pgtype.Int4{}) assert.Equal(t, pgtype.Undefined, ct.Get()) nilTests := []struct { @@ -54,7 +54,7 @@ func TestCompositeTypeSetAndGet(t *testing.T) { } func TestCompositeTypeAssignTo(t *testing.T) { - ct := pgtype.NewCompositeType(&pgtype.Text{}, &pgtype.Int4{}) + ct := pgtype.NewCompositeType("test", &pgtype.Text{}, &pgtype.Int4{}) { err := ct.Set([]interface{}{"foo", int32(42)}) @@ -161,7 +161,7 @@ func Example_composite() { return } - c := pgtype.NewCompositeType(&pgtype.Int4{}, &pgtype.Text{}) + c := pgtype.NewCompositeType("mytype", &pgtype.Int4{}, &pgtype.Text{}) conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: c, Name: "mytype", OID: oid}) var a int diff --git a/pgtype.go b/pgtype.go index 5662f4c7..091e98c4 100644 --- a/pgtype.go +++ b/pgtype.go @@ -360,9 +360,7 @@ func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) { } func (ci *ConnInfo) RegisterDataType(t DataType) { - if tv, ok := t.Value.(TypeValue); ok { - t.Value = tv.NewTypeValue() - } + t.Value = NewValue(t.Value) ci.oidToDataType[t.OID] = &t ci.nameToDataType[t.Name] = &t @@ -469,15 +467,8 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { ci2 := newConnInfo() for _, dt := range ci.oidToDataType { - var value Value - if tv, ok := dt.Value.(TypeValue); ok { - value = tv.NewTypeValue() - } else { - value = reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value) - } - ci2.RegisterDataType(DataType{ - Value: value, + Value: NewValue(dt.Value), Name: dt.Name, OID: dt.OID, }) @@ -844,6 +835,15 @@ func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) } } +// NewValue returns a new instance of the same type as v. +func NewValue(v Value) Value { + if tv, ok := v.(TypeValue); ok { + return tv.NewTypeValue() + } else { + return reflect.New(reflect.ValueOf(v).Elem().Type()).Interface().(Value) + } +} + var nameValues map[string]Value func init() {