diff --git a/pgtype.go b/pgtype.go index 71babbc8..af8d8661 100644 --- a/pgtype.go +++ b/pgtype.go @@ -2,7 +2,9 @@ package pgtype import ( "database/sql" + "net" "reflect" + "time" errors "golang.org/x/xerrors" ) @@ -207,19 +209,25 @@ type DataType struct { type ConnInfo struct { oidToDataType map[uint32]*DataType nameToDataType map[string]*DataType - reflectTypeToDataType map[reflect.Type]*DataType + reflectTypeToName map[reflect.Type]string oidToParamFormatCode map[uint32]int16 oidToResultFormatCode map[uint32]int16 + + reflectTypeToDataType map[reflect.Type]*DataType +} + +func newConnInfo() *ConnInfo { + return &ConnInfo{ + oidToDataType: make(map[uint32]*DataType), + nameToDataType: make(map[string]*DataType), + reflectTypeToName: make(map[reflect.Type]string), + oidToParamFormatCode: make(map[uint32]int16), + oidToResultFormatCode: make(map[uint32]int16), + } } func NewConnInfo() *ConnInfo { - ci := &ConnInfo{ - oidToDataType: make(map[uint32]*DataType, 128), - nameToDataType: make(map[string]*DataType, 128), - reflectTypeToDataType: make(map[reflect.Type]*DataType, 128), - oidToParamFormatCode: make(map[uint32]int16, 128), - oidToResultFormatCode: make(map[uint32]int16, 128), - } + ci := newConnInfo() ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID}) ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID}) @@ -286,6 +294,42 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID}) ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID}) + registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) { + ci.RegisterDefaultPgType(value, name) + valueType := reflect.TypeOf(value) + + ci.RegisterDefaultPgType(reflect.New(valueType).Interface(), name) + + sliceType := reflect.SliceOf(valueType) + ci.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName) + + ci.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName) + } + + // Integer types that directly map to a PostgreSQL type + registerDefaultPgTypeVariants("int2", "_int2", int16(0)) + registerDefaultPgTypeVariants("int4", "_int4", int32(0)) + registerDefaultPgTypeVariants("int8", "_int8", int64(0)) + + // Integer types that do not have a direct match to a PostgreSQL type + registerDefaultPgTypeVariants("int8", "_int8", uint16(0)) + registerDefaultPgTypeVariants("int8", "_int8", uint32(0)) + registerDefaultPgTypeVariants("int8", "_int8", uint64(0)) + registerDefaultPgTypeVariants("int8", "_int8", int(0)) + registerDefaultPgTypeVariants("int8", "_int8", uint(0)) + + registerDefaultPgTypeVariants("float4", "_float4", float32(0)) + registerDefaultPgTypeVariants("float8", "_float8", float64(0)) + + registerDefaultPgTypeVariants("bool", "_bool", false) + registerDefaultPgTypeVariants("timestamptz", "_timestamptz", time.Time{}) + registerDefaultPgTypeVariants("text", "_text", "") + registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil)) + + registerDefaultPgTypeVariants("inet", "_inet", net.IP{}) + ci.RegisterDefaultPgType((*net.IPNet)(nil), "cidr") + ci.RegisterDefaultPgType([]*net.IPNet(nil), "_cidr") + return ci } @@ -302,16 +346,12 @@ func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) { } func (ci *ConnInfo) RegisterDataType(t DataType) { - tv, _ := t.Value.(TypeValue) - if tv != nil { + if tv, ok := t.Value.(TypeValue); ok { t.Value = tv.CloneTypeValue() } ci.oidToDataType[t.OID] = &t ci.nameToDataType[t.Name] = &t - if tv == nil { - ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t - } { var formatCode int16 @@ -336,6 +376,16 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { if d, ok := t.Value.(BinaryDecoder); ok { t.binaryDecoder = d } + + ci.reflectTypeToDataType = nil // Invalidated by type registration +} + +// RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be +// encoded or decoded is determined by the PostgreSQL OID. But if the OID of a value to be encoded or decoded is +// unknown, this additional mapping will be used by DataTypeForValue to determine a suitable data type. +func (ci *ConnInfo) RegisterDefaultPgType(value interface{}, name string) { + ci.reflectTypeToName[reflect.TypeOf(value)] = name + ci.reflectTypeToDataType = nil // Invalidated by registering a default type } func (ci *ConnInfo) DataTypeForOID(oid uint32) (*DataType, bool) { @@ -348,13 +398,35 @@ func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) { return dt, ok } -func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) { +func (ci *ConnInfo) buildReflectTypeToDataType() { + ci.reflectTypeToDataType = make(map[reflect.Type]*DataType) + + for _, dt := range ci.oidToDataType { + if _, is := dt.Value.(TypeValue); !is { + ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt + } + } + + for reflectType, name := range ci.reflectTypeToName { + if dt, ok := ci.nameToDataType[name]; ok { + ci.reflectTypeToDataType[reflectType] = dt + } + } +} + +// DataTypeForValue finds a data type suitable for v. Use RegisterDataType to register types that can encode and decode +// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. +func (ci *ConnInfo) DataTypeForValue(v interface{}) (*DataType, bool) { + if ci.reflectTypeToDataType == nil { + ci.buildReflectTypeToDataType() + } + if tv, ok := v.(TypeValue); ok { dt, ok := ci.nameToDataType[tv.PgTypeName()] return dt, ok } - dt, ok := ci.reflectTypeToDataType[reflect.ValueOf(v).Type()] + dt, ok := ci.reflectTypeToDataType[reflect.TypeOf(v)] return dt, ok } @@ -376,13 +448,7 @@ func (ci *ConnInfo) ResultFormatCodeForOID(oid uint32) int16 { // DeepCopy makes a deep copy of the ConnInfo. func (ci *ConnInfo) DeepCopy() *ConnInfo { - ci2 := &ConnInfo{ - oidToDataType: make(map[uint32]*DataType, len(ci.oidToDataType)), - nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)), - reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)), - oidToParamFormatCode: make(map[uint32]int16, len(ci.oidToParamFormatCode)), - oidToResultFormatCode: make(map[uint32]int16, len(ci.oidToResultFormatCode)), - } + ci2 := newConnInfo() for _, dt := range ci.oidToDataType { var value Value @@ -399,6 +465,10 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { }) } + for t, n := range ci.reflectTypeToName { + ci2.reflectTypeToName[t] = n + } + return ci2 } @@ -416,7 +486,19 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac return errors.Errorf("unknown format code: %v", formatCode) } - if dt, ok := ci.DataTypeForOID(oid); ok { + var dt *DataType + + if oid == 0 { + if dataType, ok := ci.DataTypeForValue(dest); ok { + dt = dataType + } + } else { + if dataType, ok := ci.DataTypeForOID(oid); ok { + dt = dataType + } + } + + if dt != nil { switch formatCode { case BinaryFormatCode: if dt.binaryDecoder != nil { diff --git a/pgtype_test.go b/pgtype_test.go index dee5377d..664c5394 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -104,30 +104,39 @@ func (ct *pgCustomType) DecodeText(ci *pgtype.ConnInfo, buf []byte) error { return nil } -func TestConnInfoScanUnknownOIDToCustomType(t *testing.T) { - unknownOID := uint32(999999) +func TestConnInfoScanUnregisteredOIDToCustomType(t *testing.T) { + unregisteredOID := uint32(999999) ci := pgtype.NewConnInfo() var ct pgCustomType - err := ci.Scan(unknownOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct) + err := ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct) assert.NoError(t, err) assert.Equal(t, "foo", ct.a) assert.Equal(t, "bar", ct.b) // Scan value into pointer to custom type var pCt *pgCustomType - err = ci.Scan(unknownOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt) + err = ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt) assert.NoError(t, err) require.NotNil(t, pCt) assert.Equal(t, "foo", pCt.a) assert.Equal(t, "bar", pCt.b) // Scan null into pointer to custom type - err = ci.Scan(unknownOID, pgx.TextFormatCode, nil, &pCt) + err = ci.Scan(unregisteredOID, pgx.TextFormatCode, nil, &pCt) assert.NoError(t, err) assert.Nil(t, pCt) } +func TestConnInfoScanUnknownOIDTextFormat(t *testing.T) { + ci := pgtype.NewConnInfo() + + var n int32 + err := ci.Scan(0, pgx.TextFormatCode, []byte("123"), &n) + assert.NoError(t, err) + assert.EqualValues(t, 123, n) +} + func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) { ci := pgtype.NewConnInfo() src := []byte{0, 0, 0, 42}