Add RegisterDefaultPgType
This allows registering a mapping of a Go type to a PostgreSQL type name. If the OID of a value to be encoded or decoded is unknown, this additional mapping will be used to determine a suitable data type.
This commit is contained in:
@@ -2,7 +2,9 @@ package pgtype
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"time"
|
||||||
|
|
||||||
errors "golang.org/x/xerrors"
|
errors "golang.org/x/xerrors"
|
||||||
)
|
)
|
||||||
@@ -207,19 +209,25 @@ type DataType struct {
|
|||||||
type ConnInfo struct {
|
type ConnInfo struct {
|
||||||
oidToDataType map[uint32]*DataType
|
oidToDataType map[uint32]*DataType
|
||||||
nameToDataType map[string]*DataType
|
nameToDataType map[string]*DataType
|
||||||
reflectTypeToDataType map[reflect.Type]*DataType
|
reflectTypeToName map[reflect.Type]string
|
||||||
oidToParamFormatCode map[uint32]int16
|
oidToParamFormatCode map[uint32]int16
|
||||||
oidToResultFormatCode map[uint32]int16
|
oidToResultFormatCode map[uint32]int16
|
||||||
|
|
||||||
|
reflectTypeToDataType map[reflect.Type]*DataType
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConnInfo() *ConnInfo {
|
||||||
|
return &ConnInfo{
|
||||||
|
oidToDataType: make(map[uint32]*DataType),
|
||||||
|
nameToDataType: make(map[string]*DataType),
|
||||||
|
reflectTypeToName: make(map[reflect.Type]string),
|
||||||
|
oidToParamFormatCode: make(map[uint32]int16),
|
||||||
|
oidToResultFormatCode: make(map[uint32]int16),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnInfo() *ConnInfo {
|
func NewConnInfo() *ConnInfo {
|
||||||
ci := &ConnInfo{
|
ci := newConnInfo()
|
||||||
oidToDataType: make(map[uint32]*DataType, 128),
|
|
||||||
nameToDataType: make(map[string]*DataType, 128),
|
|
||||||
reflectTypeToDataType: make(map[reflect.Type]*DataType, 128),
|
|
||||||
oidToParamFormatCode: make(map[uint32]int16, 128),
|
|
||||||
oidToResultFormatCode: make(map[uint32]int16, 128),
|
|
||||||
}
|
|
||||||
|
|
||||||
ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID})
|
ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID})
|
||||||
ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID})
|
ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID})
|
||||||
@@ -286,6 +294,42 @@ func NewConnInfo() *ConnInfo {
|
|||||||
ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID})
|
ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID})
|
||||||
ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID})
|
ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID})
|
||||||
|
|
||||||
|
registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) {
|
||||||
|
ci.RegisterDefaultPgType(value, name)
|
||||||
|
valueType := reflect.TypeOf(value)
|
||||||
|
|
||||||
|
ci.RegisterDefaultPgType(reflect.New(valueType).Interface(), name)
|
||||||
|
|
||||||
|
sliceType := reflect.SliceOf(valueType)
|
||||||
|
ci.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName)
|
||||||
|
|
||||||
|
ci.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Integer types that directly map to a PostgreSQL type
|
||||||
|
registerDefaultPgTypeVariants("int2", "_int2", int16(0))
|
||||||
|
registerDefaultPgTypeVariants("int4", "_int4", int32(0))
|
||||||
|
registerDefaultPgTypeVariants("int8", "_int8", int64(0))
|
||||||
|
|
||||||
|
// Integer types that do not have a direct match to a PostgreSQL type
|
||||||
|
registerDefaultPgTypeVariants("int8", "_int8", uint16(0))
|
||||||
|
registerDefaultPgTypeVariants("int8", "_int8", uint32(0))
|
||||||
|
registerDefaultPgTypeVariants("int8", "_int8", uint64(0))
|
||||||
|
registerDefaultPgTypeVariants("int8", "_int8", int(0))
|
||||||
|
registerDefaultPgTypeVariants("int8", "_int8", uint(0))
|
||||||
|
|
||||||
|
registerDefaultPgTypeVariants("float4", "_float4", float32(0))
|
||||||
|
registerDefaultPgTypeVariants("float8", "_float8", float64(0))
|
||||||
|
|
||||||
|
registerDefaultPgTypeVariants("bool", "_bool", false)
|
||||||
|
registerDefaultPgTypeVariants("timestamptz", "_timestamptz", time.Time{})
|
||||||
|
registerDefaultPgTypeVariants("text", "_text", "")
|
||||||
|
registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil))
|
||||||
|
|
||||||
|
registerDefaultPgTypeVariants("inet", "_inet", net.IP{})
|
||||||
|
ci.RegisterDefaultPgType((*net.IPNet)(nil), "cidr")
|
||||||
|
ci.RegisterDefaultPgType([]*net.IPNet(nil), "_cidr")
|
||||||
|
|
||||||
return ci
|
return ci
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -302,16 +346,12 @@ func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ci *ConnInfo) RegisterDataType(t DataType) {
|
func (ci *ConnInfo) RegisterDataType(t DataType) {
|
||||||
tv, _ := t.Value.(TypeValue)
|
if tv, ok := t.Value.(TypeValue); ok {
|
||||||
if tv != nil {
|
|
||||||
t.Value = tv.CloneTypeValue()
|
t.Value = tv.CloneTypeValue()
|
||||||
}
|
}
|
||||||
|
|
||||||
ci.oidToDataType[t.OID] = &t
|
ci.oidToDataType[t.OID] = &t
|
||||||
ci.nameToDataType[t.Name] = &t
|
ci.nameToDataType[t.Name] = &t
|
||||||
if tv == nil {
|
|
||||||
ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
{
|
||||||
var formatCode int16
|
var formatCode int16
|
||||||
@@ -336,6 +376,16 @@ func (ci *ConnInfo) RegisterDataType(t DataType) {
|
|||||||
if d, ok := t.Value.(BinaryDecoder); ok {
|
if d, ok := t.Value.(BinaryDecoder); ok {
|
||||||
t.binaryDecoder = d
|
t.binaryDecoder = d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ci.reflectTypeToDataType = nil // Invalidated by type registration
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be
|
||||||
|
// encoded or decoded is determined by the PostgreSQL OID. But if the OID of a value to be encoded or decoded is
|
||||||
|
// unknown, this additional mapping will be used by DataTypeForValue to determine a suitable data type.
|
||||||
|
func (ci *ConnInfo) RegisterDefaultPgType(value interface{}, name string) {
|
||||||
|
ci.reflectTypeToName[reflect.TypeOf(value)] = name
|
||||||
|
ci.reflectTypeToDataType = nil // Invalidated by registering a default type
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ci *ConnInfo) DataTypeForOID(oid uint32) (*DataType, bool) {
|
func (ci *ConnInfo) DataTypeForOID(oid uint32) (*DataType, bool) {
|
||||||
@@ -348,13 +398,35 @@ func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) {
|
|||||||
return dt, ok
|
return dt, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) {
|
func (ci *ConnInfo) buildReflectTypeToDataType() {
|
||||||
|
ci.reflectTypeToDataType = make(map[reflect.Type]*DataType)
|
||||||
|
|
||||||
|
for _, dt := range ci.oidToDataType {
|
||||||
|
if _, is := dt.Value.(TypeValue); !is {
|
||||||
|
ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for reflectType, name := range ci.reflectTypeToName {
|
||||||
|
if dt, ok := ci.nameToDataType[name]; ok {
|
||||||
|
ci.reflectTypeToDataType[reflectType] = dt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DataTypeForValue finds a data type suitable for v. Use RegisterDataType to register types that can encode and decode
|
||||||
|
// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type.
|
||||||
|
func (ci *ConnInfo) DataTypeForValue(v interface{}) (*DataType, bool) {
|
||||||
|
if ci.reflectTypeToDataType == nil {
|
||||||
|
ci.buildReflectTypeToDataType()
|
||||||
|
}
|
||||||
|
|
||||||
if tv, ok := v.(TypeValue); ok {
|
if tv, ok := v.(TypeValue); ok {
|
||||||
dt, ok := ci.nameToDataType[tv.PgTypeName()]
|
dt, ok := ci.nameToDataType[tv.PgTypeName()]
|
||||||
return dt, ok
|
return dt, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
dt, ok := ci.reflectTypeToDataType[reflect.ValueOf(v).Type()]
|
dt, ok := ci.reflectTypeToDataType[reflect.TypeOf(v)]
|
||||||
return dt, ok
|
return dt, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -376,13 +448,7 @@ func (ci *ConnInfo) ResultFormatCodeForOID(oid uint32) int16 {
|
|||||||
|
|
||||||
// DeepCopy makes a deep copy of the ConnInfo.
|
// DeepCopy makes a deep copy of the ConnInfo.
|
||||||
func (ci *ConnInfo) DeepCopy() *ConnInfo {
|
func (ci *ConnInfo) DeepCopy() *ConnInfo {
|
||||||
ci2 := &ConnInfo{
|
ci2 := newConnInfo()
|
||||||
oidToDataType: make(map[uint32]*DataType, len(ci.oidToDataType)),
|
|
||||||
nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)),
|
|
||||||
reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)),
|
|
||||||
oidToParamFormatCode: make(map[uint32]int16, len(ci.oidToParamFormatCode)),
|
|
||||||
oidToResultFormatCode: make(map[uint32]int16, len(ci.oidToResultFormatCode)),
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, dt := range ci.oidToDataType {
|
for _, dt := range ci.oidToDataType {
|
||||||
var value Value
|
var value Value
|
||||||
@@ -399,6 +465,10 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for t, n := range ci.reflectTypeToName {
|
||||||
|
ci2.reflectTypeToName[t] = n
|
||||||
|
}
|
||||||
|
|
||||||
return ci2
|
return ci2
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -416,7 +486,19 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac
|
|||||||
return errors.Errorf("unknown format code: %v", formatCode)
|
return errors.Errorf("unknown format code: %v", formatCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
if dt, ok := ci.DataTypeForOID(oid); ok {
|
var dt *DataType
|
||||||
|
|
||||||
|
if oid == 0 {
|
||||||
|
if dataType, ok := ci.DataTypeForValue(dest); ok {
|
||||||
|
dt = dataType
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if dataType, ok := ci.DataTypeForOID(oid); ok {
|
||||||
|
dt = dataType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if dt != nil {
|
||||||
switch formatCode {
|
switch formatCode {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
if dt.binaryDecoder != nil {
|
if dt.binaryDecoder != nil {
|
||||||
|
|||||||
+14
-5
@@ -104,30 +104,39 @@ func (ct *pgCustomType) DecodeText(ci *pgtype.ConnInfo, buf []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnInfoScanUnknownOIDToCustomType(t *testing.T) {
|
func TestConnInfoScanUnregisteredOIDToCustomType(t *testing.T) {
|
||||||
unknownOID := uint32(999999)
|
unregisteredOID := uint32(999999)
|
||||||
ci := pgtype.NewConnInfo()
|
ci := pgtype.NewConnInfo()
|
||||||
|
|
||||||
var ct pgCustomType
|
var ct pgCustomType
|
||||||
err := ci.Scan(unknownOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct)
|
err := ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "foo", ct.a)
|
assert.Equal(t, "foo", ct.a)
|
||||||
assert.Equal(t, "bar", ct.b)
|
assert.Equal(t, "bar", ct.b)
|
||||||
|
|
||||||
// Scan value into pointer to custom type
|
// Scan value into pointer to custom type
|
||||||
var pCt *pgCustomType
|
var pCt *pgCustomType
|
||||||
err = ci.Scan(unknownOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt)
|
err = ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
require.NotNil(t, pCt)
|
require.NotNil(t, pCt)
|
||||||
assert.Equal(t, "foo", pCt.a)
|
assert.Equal(t, "foo", pCt.a)
|
||||||
assert.Equal(t, "bar", pCt.b)
|
assert.Equal(t, "bar", pCt.b)
|
||||||
|
|
||||||
// Scan null into pointer to custom type
|
// Scan null into pointer to custom type
|
||||||
err = ci.Scan(unknownOID, pgx.TextFormatCode, nil, &pCt)
|
err = ci.Scan(unregisteredOID, pgx.TextFormatCode, nil, &pCt)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Nil(t, pCt)
|
assert.Nil(t, pCt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnInfoScanUnknownOIDTextFormat(t *testing.T) {
|
||||||
|
ci := pgtype.NewConnInfo()
|
||||||
|
|
||||||
|
var n int32
|
||||||
|
err := ci.Scan(0, pgx.TextFormatCode, []byte("123"), &n)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 123, n)
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) {
|
func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) {
|
||||||
ci := pgtype.NewConnInfo()
|
ci := pgtype.NewConnInfo()
|
||||||
src := []byte{0, 0, 0, 42}
|
src := []byte{0, 0, 0, 42}
|
||||||
|
|||||||
Reference in New Issue
Block a user