2
0

Add RegisterDefaultPgType

This allows registering a mapping of a Go type to a PostgreSQL type
name. If the OID of a value to be encoded or decoded is unknown, this
additional mapping will be used to determine a suitable data type.
This commit is contained in:
Jack Christensen
2020-05-08 16:04:16 -05:00
parent 4a50a63f12
commit 97bbe6ae20
2 changed files with 119 additions and 28 deletions
+105 -23
View File
@@ -2,7 +2,9 @@ package pgtype
import (
"database/sql"
"net"
"reflect"
"time"
errors "golang.org/x/xerrors"
)
@@ -207,19 +209,25 @@ type DataType struct {
type ConnInfo struct {
oidToDataType map[uint32]*DataType
nameToDataType map[string]*DataType
reflectTypeToDataType map[reflect.Type]*DataType
reflectTypeToName map[reflect.Type]string
oidToParamFormatCode map[uint32]int16
oidToResultFormatCode map[uint32]int16
reflectTypeToDataType map[reflect.Type]*DataType
}
func newConnInfo() *ConnInfo {
return &ConnInfo{
oidToDataType: make(map[uint32]*DataType),
nameToDataType: make(map[string]*DataType),
reflectTypeToName: make(map[reflect.Type]string),
oidToParamFormatCode: make(map[uint32]int16),
oidToResultFormatCode: make(map[uint32]int16),
}
}
func NewConnInfo() *ConnInfo {
ci := &ConnInfo{
oidToDataType: make(map[uint32]*DataType, 128),
nameToDataType: make(map[string]*DataType, 128),
reflectTypeToDataType: make(map[reflect.Type]*DataType, 128),
oidToParamFormatCode: make(map[uint32]int16, 128),
oidToResultFormatCode: make(map[uint32]int16, 128),
}
ci := newConnInfo()
ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID})
ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID})
@@ -286,6 +294,42 @@ func NewConnInfo() *ConnInfo {
ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID})
ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID})
registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) {
ci.RegisterDefaultPgType(value, name)
valueType := reflect.TypeOf(value)
ci.RegisterDefaultPgType(reflect.New(valueType).Interface(), name)
sliceType := reflect.SliceOf(valueType)
ci.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName)
ci.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName)
}
// Integer types that directly map to a PostgreSQL type
registerDefaultPgTypeVariants("int2", "_int2", int16(0))
registerDefaultPgTypeVariants("int4", "_int4", int32(0))
registerDefaultPgTypeVariants("int8", "_int8", int64(0))
// Integer types that do not have a direct match to a PostgreSQL type
registerDefaultPgTypeVariants("int8", "_int8", uint16(0))
registerDefaultPgTypeVariants("int8", "_int8", uint32(0))
registerDefaultPgTypeVariants("int8", "_int8", uint64(0))
registerDefaultPgTypeVariants("int8", "_int8", int(0))
registerDefaultPgTypeVariants("int8", "_int8", uint(0))
registerDefaultPgTypeVariants("float4", "_float4", float32(0))
registerDefaultPgTypeVariants("float8", "_float8", float64(0))
registerDefaultPgTypeVariants("bool", "_bool", false)
registerDefaultPgTypeVariants("timestamptz", "_timestamptz", time.Time{})
registerDefaultPgTypeVariants("text", "_text", "")
registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil))
registerDefaultPgTypeVariants("inet", "_inet", net.IP{})
ci.RegisterDefaultPgType((*net.IPNet)(nil), "cidr")
ci.RegisterDefaultPgType([]*net.IPNet(nil), "_cidr")
return ci
}
@@ -302,16 +346,12 @@ func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) {
}
func (ci *ConnInfo) RegisterDataType(t DataType) {
tv, _ := t.Value.(TypeValue)
if tv != nil {
if tv, ok := t.Value.(TypeValue); ok {
t.Value = tv.CloneTypeValue()
}
ci.oidToDataType[t.OID] = &t
ci.nameToDataType[t.Name] = &t
if tv == nil {
ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t
}
{
var formatCode int16
@@ -336,6 +376,16 @@ func (ci *ConnInfo) RegisterDataType(t DataType) {
if d, ok := t.Value.(BinaryDecoder); ok {
t.binaryDecoder = d
}
ci.reflectTypeToDataType = nil // Invalidated by type registration
}
// RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be
// encoded or decoded is determined by the PostgreSQL OID. But if the OID of a value to be encoded or decoded is
// unknown, this additional mapping will be used by DataTypeForValue to determine a suitable data type.
func (ci *ConnInfo) RegisterDefaultPgType(value interface{}, name string) {
ci.reflectTypeToName[reflect.TypeOf(value)] = name
ci.reflectTypeToDataType = nil // Invalidated by registering a default type
}
func (ci *ConnInfo) DataTypeForOID(oid uint32) (*DataType, bool) {
@@ -348,13 +398,35 @@ func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) {
return dt, ok
}
func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) {
func (ci *ConnInfo) buildReflectTypeToDataType() {
ci.reflectTypeToDataType = make(map[reflect.Type]*DataType)
for _, dt := range ci.oidToDataType {
if _, is := dt.Value.(TypeValue); !is {
ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt
}
}
for reflectType, name := range ci.reflectTypeToName {
if dt, ok := ci.nameToDataType[name]; ok {
ci.reflectTypeToDataType[reflectType] = dt
}
}
}
// DataTypeForValue finds a data type suitable for v. Use RegisterDataType to register types that can encode and decode
// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type.
func (ci *ConnInfo) DataTypeForValue(v interface{}) (*DataType, bool) {
if ci.reflectTypeToDataType == nil {
ci.buildReflectTypeToDataType()
}
if tv, ok := v.(TypeValue); ok {
dt, ok := ci.nameToDataType[tv.PgTypeName()]
return dt, ok
}
dt, ok := ci.reflectTypeToDataType[reflect.ValueOf(v).Type()]
dt, ok := ci.reflectTypeToDataType[reflect.TypeOf(v)]
return dt, ok
}
@@ -376,13 +448,7 @@ func (ci *ConnInfo) ResultFormatCodeForOID(oid uint32) int16 {
// DeepCopy makes a deep copy of the ConnInfo.
func (ci *ConnInfo) DeepCopy() *ConnInfo {
ci2 := &ConnInfo{
oidToDataType: make(map[uint32]*DataType, len(ci.oidToDataType)),
nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)),
reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)),
oidToParamFormatCode: make(map[uint32]int16, len(ci.oidToParamFormatCode)),
oidToResultFormatCode: make(map[uint32]int16, len(ci.oidToResultFormatCode)),
}
ci2 := newConnInfo()
for _, dt := range ci.oidToDataType {
var value Value
@@ -399,6 +465,10 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo {
})
}
for t, n := range ci.reflectTypeToName {
ci2.reflectTypeToName[t] = n
}
return ci2
}
@@ -416,7 +486,19 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac
return errors.Errorf("unknown format code: %v", formatCode)
}
if dt, ok := ci.DataTypeForOID(oid); ok {
var dt *DataType
if oid == 0 {
if dataType, ok := ci.DataTypeForValue(dest); ok {
dt = dataType
}
} else {
if dataType, ok := ci.DataTypeForOID(oid); ok {
dt = dataType
}
}
if dt != nil {
switch formatCode {
case BinaryFormatCode:
if dt.binaryDecoder != nil {
+14 -5
View File
@@ -104,30 +104,39 @@ func (ct *pgCustomType) DecodeText(ci *pgtype.ConnInfo, buf []byte) error {
return nil
}
func TestConnInfoScanUnknownOIDToCustomType(t *testing.T) {
unknownOID := uint32(999999)
func TestConnInfoScanUnregisteredOIDToCustomType(t *testing.T) {
unregisteredOID := uint32(999999)
ci := pgtype.NewConnInfo()
var ct pgCustomType
err := ci.Scan(unknownOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct)
err := ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct)
assert.NoError(t, err)
assert.Equal(t, "foo", ct.a)
assert.Equal(t, "bar", ct.b)
// Scan value into pointer to custom type
var pCt *pgCustomType
err = ci.Scan(unknownOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt)
err = ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt)
assert.NoError(t, err)
require.NotNil(t, pCt)
assert.Equal(t, "foo", pCt.a)
assert.Equal(t, "bar", pCt.b)
// Scan null into pointer to custom type
err = ci.Scan(unknownOID, pgx.TextFormatCode, nil, &pCt)
err = ci.Scan(unregisteredOID, pgx.TextFormatCode, nil, &pCt)
assert.NoError(t, err)
assert.Nil(t, pCt)
}
func TestConnInfoScanUnknownOIDTextFormat(t *testing.T) {
ci := pgtype.NewConnInfo()
var n int32
err := ci.Scan(0, pgx.TextFormatCode, []byte("123"), &n)
assert.NoError(t, err)
assert.EqualValues(t, 123, n)
}
func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) {
ci := pgtype.NewConnInfo()
src := []byte{0, 0, 0, 42}