Add RegisterDefaultPgType
This allows registering a mapping of a Go type to a PostgreSQL type name. If the OID of a value to be encoded or decoded is unknown, this additional mapping will be used to determine a suitable data type.
This commit is contained in:
@@ -2,7 +2,9 @@ package pgtype
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
errors "golang.org/x/xerrors"
|
||||
)
|
||||
@@ -207,19 +209,25 @@ type DataType struct {
|
||||
type ConnInfo struct {
|
||||
oidToDataType map[uint32]*DataType
|
||||
nameToDataType map[string]*DataType
|
||||
reflectTypeToDataType map[reflect.Type]*DataType
|
||||
reflectTypeToName map[reflect.Type]string
|
||||
oidToParamFormatCode map[uint32]int16
|
||||
oidToResultFormatCode map[uint32]int16
|
||||
|
||||
reflectTypeToDataType map[reflect.Type]*DataType
|
||||
}
|
||||
|
||||
func newConnInfo() *ConnInfo {
|
||||
return &ConnInfo{
|
||||
oidToDataType: make(map[uint32]*DataType),
|
||||
nameToDataType: make(map[string]*DataType),
|
||||
reflectTypeToName: make(map[reflect.Type]string),
|
||||
oidToParamFormatCode: make(map[uint32]int16),
|
||||
oidToResultFormatCode: make(map[uint32]int16),
|
||||
}
|
||||
}
|
||||
|
||||
func NewConnInfo() *ConnInfo {
|
||||
ci := &ConnInfo{
|
||||
oidToDataType: make(map[uint32]*DataType, 128),
|
||||
nameToDataType: make(map[string]*DataType, 128),
|
||||
reflectTypeToDataType: make(map[reflect.Type]*DataType, 128),
|
||||
oidToParamFormatCode: make(map[uint32]int16, 128),
|
||||
oidToResultFormatCode: make(map[uint32]int16, 128),
|
||||
}
|
||||
ci := newConnInfo()
|
||||
|
||||
ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID})
|
||||
ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID})
|
||||
@@ -286,6 +294,42 @@ func NewConnInfo() *ConnInfo {
|
||||
ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID})
|
||||
ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID})
|
||||
|
||||
registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) {
|
||||
ci.RegisterDefaultPgType(value, name)
|
||||
valueType := reflect.TypeOf(value)
|
||||
|
||||
ci.RegisterDefaultPgType(reflect.New(valueType).Interface(), name)
|
||||
|
||||
sliceType := reflect.SliceOf(valueType)
|
||||
ci.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName)
|
||||
|
||||
ci.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName)
|
||||
}
|
||||
|
||||
// Integer types that directly map to a PostgreSQL type
|
||||
registerDefaultPgTypeVariants("int2", "_int2", int16(0))
|
||||
registerDefaultPgTypeVariants("int4", "_int4", int32(0))
|
||||
registerDefaultPgTypeVariants("int8", "_int8", int64(0))
|
||||
|
||||
// Integer types that do not have a direct match to a PostgreSQL type
|
||||
registerDefaultPgTypeVariants("int8", "_int8", uint16(0))
|
||||
registerDefaultPgTypeVariants("int8", "_int8", uint32(0))
|
||||
registerDefaultPgTypeVariants("int8", "_int8", uint64(0))
|
||||
registerDefaultPgTypeVariants("int8", "_int8", int(0))
|
||||
registerDefaultPgTypeVariants("int8", "_int8", uint(0))
|
||||
|
||||
registerDefaultPgTypeVariants("float4", "_float4", float32(0))
|
||||
registerDefaultPgTypeVariants("float8", "_float8", float64(0))
|
||||
|
||||
registerDefaultPgTypeVariants("bool", "_bool", false)
|
||||
registerDefaultPgTypeVariants("timestamptz", "_timestamptz", time.Time{})
|
||||
registerDefaultPgTypeVariants("text", "_text", "")
|
||||
registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil))
|
||||
|
||||
registerDefaultPgTypeVariants("inet", "_inet", net.IP{})
|
||||
ci.RegisterDefaultPgType((*net.IPNet)(nil), "cidr")
|
||||
ci.RegisterDefaultPgType([]*net.IPNet(nil), "_cidr")
|
||||
|
||||
return ci
|
||||
}
|
||||
|
||||
@@ -302,16 +346,12 @@ func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) {
|
||||
}
|
||||
|
||||
func (ci *ConnInfo) RegisterDataType(t DataType) {
|
||||
tv, _ := t.Value.(TypeValue)
|
||||
if tv != nil {
|
||||
if tv, ok := t.Value.(TypeValue); ok {
|
||||
t.Value = tv.CloneTypeValue()
|
||||
}
|
||||
|
||||
ci.oidToDataType[t.OID] = &t
|
||||
ci.nameToDataType[t.Name] = &t
|
||||
if tv == nil {
|
||||
ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t
|
||||
}
|
||||
|
||||
{
|
||||
var formatCode int16
|
||||
@@ -336,6 +376,16 @@ func (ci *ConnInfo) RegisterDataType(t DataType) {
|
||||
if d, ok := t.Value.(BinaryDecoder); ok {
|
||||
t.binaryDecoder = d
|
||||
}
|
||||
|
||||
ci.reflectTypeToDataType = nil // Invalidated by type registration
|
||||
}
|
||||
|
||||
// RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be
|
||||
// encoded or decoded is determined by the PostgreSQL OID. But if the OID of a value to be encoded or decoded is
|
||||
// unknown, this additional mapping will be used by DataTypeForValue to determine a suitable data type.
|
||||
func (ci *ConnInfo) RegisterDefaultPgType(value interface{}, name string) {
|
||||
ci.reflectTypeToName[reflect.TypeOf(value)] = name
|
||||
ci.reflectTypeToDataType = nil // Invalidated by registering a default type
|
||||
}
|
||||
|
||||
func (ci *ConnInfo) DataTypeForOID(oid uint32) (*DataType, bool) {
|
||||
@@ -348,13 +398,35 @@ func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) {
|
||||
return dt, ok
|
||||
}
|
||||
|
||||
func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) {
|
||||
func (ci *ConnInfo) buildReflectTypeToDataType() {
|
||||
ci.reflectTypeToDataType = make(map[reflect.Type]*DataType)
|
||||
|
||||
for _, dt := range ci.oidToDataType {
|
||||
if _, is := dt.Value.(TypeValue); !is {
|
||||
ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt
|
||||
}
|
||||
}
|
||||
|
||||
for reflectType, name := range ci.reflectTypeToName {
|
||||
if dt, ok := ci.nameToDataType[name]; ok {
|
||||
ci.reflectTypeToDataType[reflectType] = dt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DataTypeForValue finds a data type suitable for v. Use RegisterDataType to register types that can encode and decode
|
||||
// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type.
|
||||
func (ci *ConnInfo) DataTypeForValue(v interface{}) (*DataType, bool) {
|
||||
if ci.reflectTypeToDataType == nil {
|
||||
ci.buildReflectTypeToDataType()
|
||||
}
|
||||
|
||||
if tv, ok := v.(TypeValue); ok {
|
||||
dt, ok := ci.nameToDataType[tv.PgTypeName()]
|
||||
return dt, ok
|
||||
}
|
||||
|
||||
dt, ok := ci.reflectTypeToDataType[reflect.ValueOf(v).Type()]
|
||||
dt, ok := ci.reflectTypeToDataType[reflect.TypeOf(v)]
|
||||
return dt, ok
|
||||
}
|
||||
|
||||
@@ -376,13 +448,7 @@ func (ci *ConnInfo) ResultFormatCodeForOID(oid uint32) int16 {
|
||||
|
||||
// DeepCopy makes a deep copy of the ConnInfo.
|
||||
func (ci *ConnInfo) DeepCopy() *ConnInfo {
|
||||
ci2 := &ConnInfo{
|
||||
oidToDataType: make(map[uint32]*DataType, len(ci.oidToDataType)),
|
||||
nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)),
|
||||
reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)),
|
||||
oidToParamFormatCode: make(map[uint32]int16, len(ci.oidToParamFormatCode)),
|
||||
oidToResultFormatCode: make(map[uint32]int16, len(ci.oidToResultFormatCode)),
|
||||
}
|
||||
ci2 := newConnInfo()
|
||||
|
||||
for _, dt := range ci.oidToDataType {
|
||||
var value Value
|
||||
@@ -399,6 +465,10 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo {
|
||||
})
|
||||
}
|
||||
|
||||
for t, n := range ci.reflectTypeToName {
|
||||
ci2.reflectTypeToName[t] = n
|
||||
}
|
||||
|
||||
return ci2
|
||||
}
|
||||
|
||||
@@ -416,7 +486,19 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac
|
||||
return errors.Errorf("unknown format code: %v", formatCode)
|
||||
}
|
||||
|
||||
if dt, ok := ci.DataTypeForOID(oid); ok {
|
||||
var dt *DataType
|
||||
|
||||
if oid == 0 {
|
||||
if dataType, ok := ci.DataTypeForValue(dest); ok {
|
||||
dt = dataType
|
||||
}
|
||||
} else {
|
||||
if dataType, ok := ci.DataTypeForOID(oid); ok {
|
||||
dt = dataType
|
||||
}
|
||||
}
|
||||
|
||||
if dt != nil {
|
||||
switch formatCode {
|
||||
case BinaryFormatCode:
|
||||
if dt.binaryDecoder != nil {
|
||||
|
||||
+14
-5
@@ -104,30 +104,39 @@ func (ct *pgCustomType) DecodeText(ci *pgtype.ConnInfo, buf []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConnInfoScanUnknownOIDToCustomType(t *testing.T) {
|
||||
unknownOID := uint32(999999)
|
||||
func TestConnInfoScanUnregisteredOIDToCustomType(t *testing.T) {
|
||||
unregisteredOID := uint32(999999)
|
||||
ci := pgtype.NewConnInfo()
|
||||
|
||||
var ct pgCustomType
|
||||
err := ci.Scan(unknownOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct)
|
||||
err := ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "foo", ct.a)
|
||||
assert.Equal(t, "bar", ct.b)
|
||||
|
||||
// Scan value into pointer to custom type
|
||||
var pCt *pgCustomType
|
||||
err = ci.Scan(unknownOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt)
|
||||
err = ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt)
|
||||
assert.NoError(t, err)
|
||||
require.NotNil(t, pCt)
|
||||
assert.Equal(t, "foo", pCt.a)
|
||||
assert.Equal(t, "bar", pCt.b)
|
||||
|
||||
// Scan null into pointer to custom type
|
||||
err = ci.Scan(unknownOID, pgx.TextFormatCode, nil, &pCt)
|
||||
err = ci.Scan(unregisteredOID, pgx.TextFormatCode, nil, &pCt)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, pCt)
|
||||
}
|
||||
|
||||
func TestConnInfoScanUnknownOIDTextFormat(t *testing.T) {
|
||||
ci := pgtype.NewConnInfo()
|
||||
|
||||
var n int32
|
||||
err := ci.Scan(0, pgx.TextFormatCode, []byte("123"), &n)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 123, n)
|
||||
}
|
||||
|
||||
func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) {
|
||||
ci := pgtype.NewConnInfo()
|
||||
src := []byte{0, 0, 0, 42}
|
||||
|
||||
Reference in New Issue
Block a user