diff --git a/pgtype/array.go b/pgtype/array.go index 174007c1..29d6f803 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -28,6 +28,20 @@ type ArrayDimension struct { LowerBound int32 } +// cardinality returns the number of elements in an array of dimensions size. +func cardinality(dimensions []ArrayDimension) int { + if len(dimensions) == 0 { + return 0 + } + + elementCount := int(dimensions[0].Length) + for _, d := range dimensions[1:] { + elementCount *= int(d.Length) + } + + return elementCount +} + func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { if len(src) < 12 { return 0, fmt.Errorf("array header too short: %d", len(src)) diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go new file mode 100644 index 00000000..b72290a0 --- /dev/null +++ b/pgtype/array_codec.go @@ -0,0 +1,352 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + + "github.com/jackc/pgio" +) + +// ArrayGetter is a type that can be converted into a PostgreSQL array. +type ArrayGetter interface { + // Dimensions returns the array dimensions. If array is nil then nil is returned. + Dimensions() []ArrayDimension + + // Index returns the element at i. + Index(i int) interface{} +} + +// ArraySetter is a type can be set from a PostgreSQL array. +type ArraySetter interface { + // SetDimensions prepares the value such that ScanIndex can be called for each element. dimensions may be nil to + // indicate a NULL array. If unable to exactly preserve dimensions SetDimensions may return an error or silently + // flatten the array dimensions. + SetDimensions(dimensions []ArrayDimension) error + + // ScanIndex returns a value usable as a scan target for i. SetDimensions must be called before ScanIndex. + ScanIndex(i int) interface{} +} + +type int16Array []int16 + +func (a int16Array) Dimensions() []ArrayDimension { + if a == nil { + return nil + } + + return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} +} + +func (a int16Array) Index(i int) interface{} { + return a[i] +} + +func (a *int16Array) SetDimensions(dimensions []ArrayDimension) error { + if dimensions == nil { + a = nil + return nil + } + + elementCount := cardinality(dimensions) + *a = make(int16Array, elementCount) + return nil +} + +func (a int16Array) ScanIndex(i int) interface{} { + return &a[i] +} + +func makeArrayGetter(a interface{}) (ArrayGetter, error) { + switch a := a.(type) { + case ArrayGetter: + return a, nil + case []int16: + return (*int16Array)(&a), nil + } + + return nil, fmt.Errorf("cannot convert %T to ArrayGetter", a) +} + +func makeArraySetter(a interface{}) (ArraySetter, error) { + switch a := a.(type) { + case ArraySetter: + return a, nil + case *[]int16: + return (*int16Array)(a), nil + } + + return nil, fmt.Errorf("cannot convert %T to ArraySetter", a) +} + +// ArrayCodec is a codec for any array type. +type ArrayCodec struct { + ElementCodec Codec + ElementOID uint32 +} + +func (c *ArrayCodec) FormatSupported(format int16) bool { + return c.ElementCodec.FormatSupported(format) +} + +func (c *ArrayCodec) PreferredFormat() int16 { + return c.ElementCodec.PreferredFormat() +} + +func (c *ArrayCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { + if value == nil { + return nil, nil + } + + array, err := makeArrayGetter(value) + if err != nil { + return nil, err + } + + switch format { + case BinaryFormatCode: + return c.encodeBinary(ci, oid, array, buf) + case TextFormatCode: + return c.encodeText(ci, oid, array, buf) + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } + +} + +func (c *ArrayCodec) encodeBinary(ci *ConnInfo, oid uint32, array ArrayGetter, buf []byte) (newBuf []byte, err error) { + dimensions := array.Dimensions() + if dimensions == nil { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: dimensions, + ElementOID: int32(c.ElementOID), + } + + containsNullIndex := len(buf) + 4 + + buf = arrayHeader.EncodeBinary(ci, buf) + + elementCount := cardinality(dimensions) + for i := 0; i < elementCount; i++ { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := c.ElementCodec.Encode(ci, c.ElementOID, BinaryFormatCode, array.Index(i), buf) + if err != nil { + return nil, err + } + if elemBuf == nil { + pgio.SetInt32(buf[containsNullIndex:], 1) + } else { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +func (c *ArrayCodec) encodeText(ci *ConnInfo, oid uint32, array ArrayGetter, buf []byte) (newBuf []byte, err error) { + dimensions := array.Dimensions() + if dimensions == nil { + return nil, nil + } + + if len(dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(dimensions)) + dimElemCounts[len(dimensions)-1] = int(dimensions[len(dimensions)-1].Length) + for i := len(dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + elementCount := cardinality(dimensions) + for i := 0; i < elementCount; i++ { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := c.ElementCodec.Encode(ci, c.ElementOID, TextFormatCode, array.Index(i), inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (c *ArrayCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + _, err := makeArraySetter(target) + if err != nil { + return nil + } + + return (*scanPlanArrayCodec)(c) +} + +func (c *ArrayCodec) decodeBinary(ci *ConnInfo, arrayOID uint32, src []byte, array ArraySetter) error { + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + // TODO - ArrayHeader.DecodeBinary should do this. But doing this there breaks old array code. Leave until old code + // can be removed. + if arrayHeader.Dimensions == nil { + arrayHeader.Dimensions = []ArrayDimension{} + } + + err = array.SetDimensions(arrayHeader.Dimensions) + if err != nil { + return err + } + + elementCount := cardinality(arrayHeader.Dimensions) + if elementCount == 0 { + return nil + } + + elementScanPlan := c.ElementCodec.PlanScan(ci, c.ElementOID, BinaryFormatCode, array.ScanIndex(0), false) + if elementScanPlan == nil { + elementScanPlan = ci.PlanScan(c.ElementOID, BinaryFormatCode, array.ScanIndex(0)) + } + + for i := 0; i < elementCount; i++ { + elem := array.ScanIndex(i) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elementScanPlan.Scan(ci, c.ElementOID, BinaryFormatCode, elemSrc, elem) + if err != nil { + return err + } + } + + return nil +} + +func (c *ArrayCodec) decodeText(ci *ConnInfo, arrayOID uint32, src []byte, array ArraySetter) error { + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + // TODO - ParseUntypedTextArray should do this. But doing this there breaks old array code. Leave until old code + // can be removed. + if uta.Dimensions == nil { + uta.Dimensions = []ArrayDimension{} + } + + err = array.SetDimensions(uta.Dimensions) + if err != nil { + return err + } + + if len(uta.Elements) == 0 { + return nil + } + + elementScanPlan := c.ElementCodec.PlanScan(ci, c.ElementOID, TextFormatCode, array.ScanIndex(0), false) + if elementScanPlan == nil { + elementScanPlan = ci.PlanScan(c.ElementOID, TextFormatCode, array.ScanIndex(0)) + } + + for i, s := range uta.Elements { + elem := array.ScanIndex(i) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + + err = elementScanPlan.Scan(ci, c.ElementOID, TextFormatCode, elemSrc, elem) + if err != nil { + return err + } + } + + return nil +} + +type scanPlanArrayCodec ArrayCodec + +func (spac *scanPlanArrayCodec) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + c := (*ArrayCodec)(spac) + + array, err := makeArraySetter(dst) + if err != nil { + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + + if src == nil { + return array.SetDimensions(nil) + } + + switch formatCode { + case BinaryFormatCode: + return c.decodeBinary(ci, oid, src, array) + case TextFormatCode: + return c.decodeText(ci, oid, src, array) + default: + return fmt.Errorf("unknown format code %d", formatCode) + } +} + +func (c ArrayCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + // var n int64 + // err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) + // return n, err + + return nil, fmt.Errorf("not implemented") +} + +func (c ArrayCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + // var n int16 + // err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) + // return n, err + + return nil, fmt.Errorf("not implemented") +} diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go new file mode 100644 index 00000000..f213d0ec --- /dev/null +++ b/pgtype/array_codec_test.go @@ -0,0 +1,106 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/stretchr/testify/assert" +) + +func TestArrayCodec(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + tests := []struct { + expected []int16 + }{ + {[]int16(nil)}, + {[]int16{}}, + {[]int16{1, 2, 3}}, + } + for i, tt := range tests { + var actual []int16 + err := conn.QueryRow( + context.Background(), + "select $1::smallint[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } +} + +// func TestArrayCodecValue(t *testing.T) { +// ArrayCodec := pgtype.NewArrayCodec("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }) + +// err := ArrayCodec.Set(nil) +// require.NoError(t, err) + +// gotValue := ArrayCodec.Get() +// require.Nil(t, gotValue) + +// slice := []string{"foo", "bar"} +// err = ArrayCodec.AssignTo(&slice) +// require.NoError(t, err) +// require.Nil(t, slice) + +// err = ArrayCodec.Set([]string{}) +// require.NoError(t, err) + +// gotValue = ArrayCodec.Get() +// require.Len(t, gotValue, 0) + +// err = ArrayCodec.AssignTo(&slice) +// require.NoError(t, err) +// require.EqualValues(t, []string{}, slice) + +// err = ArrayCodec.Set([]string{"baz", "quz"}) +// require.NoError(t, err) + +// gotValue = ArrayCodec.Get() +// require.Len(t, gotValue, 2) + +// err = ArrayCodec.AssignTo(&slice) +// require.NoError(t, err) +// require.EqualValues(t, []string{"baz", "quz"}, slice) +// } + +// func TestArrayCodecTranscode(t *testing.T) { +// conn := testutil.MustConnectPgx(t) +// defer testutil.MustCloseContext(t, conn) + +// conn.ConnInfo().RegisterDataType(pgtype.DataType{ +// Value: pgtype.NewArrayCodec("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), +// Name: "_text", +// OID: pgtype.TextArrayOID, +// }) + +// var dstStrings []string +// err := conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) +// require.NoError(t, err) + +// require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) +// } + +// func TestArrayCodecEmptyArrayDoesNotBreakArrayCodec(t *testing.T) { +// conn := testutil.MustConnectPgx(t) +// defer testutil.MustCloseContext(t, conn) + +// conn.ConnInfo().RegisterDataType(pgtype.DataType{ +// Value: pgtype.NewArrayCodec("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), +// Name: "_text", +// OID: pgtype.TextArrayOID, +// }) + +// var dstStrings []string +// err := conn.QueryRow(context.Background(), "select '{}'::text[]").Scan(&dstStrings) +// require.NoError(t, err) + +// require.EqualValues(t, []string{}, dstStrings) + +// err = conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) +// require.NoError(t, err) + +// require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) +// } diff --git a/pgtype/convert.go b/pgtype/convert.go index 21e208f5..ee5ba393 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -5,6 +5,7 @@ import ( "fmt" "math" "reflect" + "strconv" "time" ) @@ -452,6 +453,141 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { return nil, false } +func convertToInt64ForEncode(v interface{}) (n int64, valid bool, err error) { + if v == nil { + return 0, false, nil + } + + switch v := v.(type) { + case int8: + return int64(v), true, nil + case uint8: + return int64(v), true, nil + case int16: + return int64(v), true, nil + case uint16: + return int64(v), true, nil + case int32: + return int64(v), true, nil + case uint32: + return int64(v), true, nil + case int64: + return int64(v), true, nil + case uint64: + if v > math.MaxInt64 { + return 0, false, fmt.Errorf("%d is greater than maximum value for int64", v) + } + return int64(v), true, nil + case int: + return int64(v), true, nil + case uint: + if v > math.MaxInt64 { + return 0, false, fmt.Errorf("%d is greater than maximum value for int64", v) + } + return int64(v), true, nil + case string: + num, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return 0, false, err + } + return num, true, nil + case float32: + if v > math.MaxInt64 { + return 0, false, fmt.Errorf("%f is greater than maximum value for int64", v) + } + return int64(v), true, nil + case float64: + if v > math.MaxInt64 { + return 0, false, fmt.Errorf("%f is greater than maximum value for int64", v) + } + return int64(v), true, nil + case *int8: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *uint8: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *int16: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *uint16: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *int32: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *uint32: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *int64: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *uint64: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *int: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *uint: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *string: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *float32: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + case *float64: + if v == nil { + return 0, false, nil + } else { + return convertToInt64ForEncode(*v) + } + + default: + if originalvalue, ok := underlyingNumberType(v); ok { + return convertToInt64ForEncode(originalvalue) + } + return 0, false, fmt.Errorf("cannot convert %v to int64", v) + } +} + func init() { kindTypes = map[reflect.Kind]reflect.Type{ reflect.Bool: reflect.TypeOf(false), diff --git a/pgtype/int2_codec.go b/pgtype/int2_codec.go new file mode 100644 index 00000000..7ea50870 --- /dev/null +++ b/pgtype/int2_codec.go @@ -0,0 +1,146 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +type Int2Codec struct{} + +func (Int2Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Int2Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Int2Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { + n, valid, err := convertToInt64ForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to int2: %v", value, err) + } + if !valid { + return nil, nil + } + + if n > math.MaxInt16 { + return nil, fmt.Errorf("%d is greater than maximum value for int2", n) + } + if n < math.MinInt16 { + return nil, fmt.Errorf("%d is less than minimum value for int2", n) + } + + switch format { + case BinaryFormatCode: + return pgio.AppendInt16(buf, int16(n)), nil + case TextFormatCode: + return append(buf, strconv.FormatInt(n, 10)...), nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } +} + +func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { + case BinaryFormatCode: + case TextFormatCode: + switch target.(type) { + case *int16: + return scanPlanTextToAnyInt16{} + case *int32: + return scanPlanTextToAnyInt32{} + case *int64: + return scanPlanTextToAnyInt64{} + } + } + + return nil +} + +func (c Int2Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n int64 + err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) + return n, err +} + +func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var n int16 + err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) + return n, err +} + +type scanPlanTextToAnyInt16 struct{} + +func (scanPlanTextToAnyInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 16) + if err != nil { + return err + } + + *p = int16(n) + return nil +} + +type scanPlanTextToAnyInt32 struct{} + +func (scanPlanTextToAnyInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int32) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 32) + if err != nil { + return err + } + + *p = int32(n) + return nil +} + +type scanPlanTextToAnyInt64 struct{} + +func (scanPlanTextToAnyInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int64) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + *p = int64(n) + return nil +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index d8dd5abf..b0b07663 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -2,7 +2,9 @@ package pgtype import ( "database/sql" + "database/sql/driver" "encoding/binary" + "errors" "fmt" "math" "net" @@ -173,6 +175,34 @@ type ResultDecoder interface { DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error } +type Encoder interface { + // Encode appends the encoded bytes of value to buf. If value is the SQL NULL then append nothing and return + // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data + // written. + Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) +} + +type Codec interface { + // FormatSupported returns true if the format is supported. + FormatSupported(int16) bool + + // PreferredFormat returns the preferred format. + PreferredFormat() int16 + + Encoder + + // PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If + // actualTarget is true then the returned ScanPlan may be optimized to directly scan into target. If no plan can be + // found then nil is returned. + PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan + + // DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface. + DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) + + // DecodeValue returns src decoded into its default format. + DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) +} + // ResultFormatPreferrer allows a type to specify its preferred result format instead of it being inferred from // whether it is also a BinaryDecoder. type ResultFormatPreferrer interface { @@ -229,6 +259,8 @@ type DataType struct { textDecoder TextDecoder binaryDecoder BinaryDecoder + Codec Codec + Name string OID uint32 } @@ -268,7 +300,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID}) ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID}) ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID}) - ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID}) + ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}}) ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID}) ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID}) ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) @@ -292,7 +324,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID}) ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID}) - ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID}) + ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID}) ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID}) @@ -752,6 +784,15 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt // PlanScan prepares a plan to scan a value into dst. func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan { + if oid != 0 { + if dt, ok := ci.DataTypeForOID(oid); ok && dt.Codec != nil { + plan := dt.Codec.PlanScan(ci, oid, formatCode, dst, false) + if plan != nil { + return plan + } + } + } + switch formatCode { case BinaryFormatCode: switch dst.(type) { @@ -866,6 +907,8 @@ func NewValue(v Value) Value { } } +var ErrScanTargetTypeChanged = errors.New("scan target type changed") + var nameValues map[string]Value func init() {