Initial codec support for int2 and int2[]
This commit is contained in:
@@ -28,6 +28,20 @@ type ArrayDimension struct {
|
|||||||
LowerBound int32
|
LowerBound int32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cardinality returns the number of elements in an array of dimensions size.
|
||||||
|
func cardinality(dimensions []ArrayDimension) int {
|
||||||
|
if len(dimensions) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
elementCount := int(dimensions[0].Length)
|
||||||
|
for _, d := range dimensions[1:] {
|
||||||
|
elementCount *= int(d.Length)
|
||||||
|
}
|
||||||
|
|
||||||
|
return elementCount
|
||||||
|
}
|
||||||
|
|
||||||
func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) {
|
func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) {
|
||||||
if len(src) < 12 {
|
if len(src) < 12 {
|
||||||
return 0, fmt.Errorf("array header too short: %d", len(src))
|
return 0, fmt.Errorf("array header too short: %d", len(src))
|
||||||
|
|||||||
@@ -0,0 +1,352 @@
|
|||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ArrayGetter is a type that can be converted into a PostgreSQL array.
|
||||||
|
type ArrayGetter interface {
|
||||||
|
// Dimensions returns the array dimensions. If array is nil then nil is returned.
|
||||||
|
Dimensions() []ArrayDimension
|
||||||
|
|
||||||
|
// Index returns the element at i.
|
||||||
|
Index(i int) interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ArraySetter is a type can be set from a PostgreSQL array.
|
||||||
|
type ArraySetter interface {
|
||||||
|
// SetDimensions prepares the value such that ScanIndex can be called for each element. dimensions may be nil to
|
||||||
|
// indicate a NULL array. If unable to exactly preserve dimensions SetDimensions may return an error or silently
|
||||||
|
// flatten the array dimensions.
|
||||||
|
SetDimensions(dimensions []ArrayDimension) error
|
||||||
|
|
||||||
|
// ScanIndex returns a value usable as a scan target for i. SetDimensions must be called before ScanIndex.
|
||||||
|
ScanIndex(i int) interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type int16Array []int16
|
||||||
|
|
||||||
|
func (a int16Array) Dimensions() []ArrayDimension {
|
||||||
|
if a == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a int16Array) Index(i int) interface{} {
|
||||||
|
return a[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *int16Array) SetDimensions(dimensions []ArrayDimension) error {
|
||||||
|
if dimensions == nil {
|
||||||
|
a = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
elementCount := cardinality(dimensions)
|
||||||
|
*a = make(int16Array, elementCount)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a int16Array) ScanIndex(i int) interface{} {
|
||||||
|
return &a[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeArrayGetter(a interface{}) (ArrayGetter, error) {
|
||||||
|
switch a := a.(type) {
|
||||||
|
case ArrayGetter:
|
||||||
|
return a, nil
|
||||||
|
case []int16:
|
||||||
|
return (*int16Array)(&a), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("cannot convert %T to ArrayGetter", a)
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeArraySetter(a interface{}) (ArraySetter, error) {
|
||||||
|
switch a := a.(type) {
|
||||||
|
case ArraySetter:
|
||||||
|
return a, nil
|
||||||
|
case *[]int16:
|
||||||
|
return (*int16Array)(a), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("cannot convert %T to ArraySetter", a)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ArrayCodec is a codec for any array type.
|
||||||
|
type ArrayCodec struct {
|
||||||
|
ElementCodec Codec
|
||||||
|
ElementOID uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ArrayCodec) FormatSupported(format int16) bool {
|
||||||
|
return c.ElementCodec.FormatSupported(format)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ArrayCodec) PreferredFormat() int16 {
|
||||||
|
return c.ElementCodec.PreferredFormat()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ArrayCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) {
|
||||||
|
if value == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
array, err := makeArrayGetter(value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch format {
|
||||||
|
case BinaryFormatCode:
|
||||||
|
return c.encodeBinary(ci, oid, array, buf)
|
||||||
|
case TextFormatCode:
|
||||||
|
return c.encodeText(ci, oid, array, buf)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown format code: %v", format)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ArrayCodec) encodeBinary(ci *ConnInfo, oid uint32, array ArrayGetter, buf []byte) (newBuf []byte, err error) {
|
||||||
|
dimensions := array.Dimensions()
|
||||||
|
if dimensions == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
arrayHeader := ArrayHeader{
|
||||||
|
Dimensions: dimensions,
|
||||||
|
ElementOID: int32(c.ElementOID),
|
||||||
|
}
|
||||||
|
|
||||||
|
containsNullIndex := len(buf) + 4
|
||||||
|
|
||||||
|
buf = arrayHeader.EncodeBinary(ci, buf)
|
||||||
|
|
||||||
|
elementCount := cardinality(dimensions)
|
||||||
|
for i := 0; i < elementCount; i++ {
|
||||||
|
sp := len(buf)
|
||||||
|
buf = pgio.AppendInt32(buf, -1)
|
||||||
|
|
||||||
|
elemBuf, err := c.ElementCodec.Encode(ci, c.ElementOID, BinaryFormatCode, array.Index(i), buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if elemBuf == nil {
|
||||||
|
pgio.SetInt32(buf[containsNullIndex:], 1)
|
||||||
|
} else {
|
||||||
|
buf = elemBuf
|
||||||
|
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ArrayCodec) encodeText(ci *ConnInfo, oid uint32, array ArrayGetter, buf []byte) (newBuf []byte, err error) {
|
||||||
|
dimensions := array.Dimensions()
|
||||||
|
if dimensions == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(dimensions) == 0 {
|
||||||
|
return append(buf, '{', '}'), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = EncodeTextArrayDimensions(buf, dimensions)
|
||||||
|
|
||||||
|
// dimElemCounts is the multiples of elements that each array lies on. For
|
||||||
|
// example, a single dimension array of length 4 would have a dimElemCounts of
|
||||||
|
// [4]. A multi-dimensional array of lengths [3,5,2] would have a
|
||||||
|
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
|
||||||
|
// or '}'.
|
||||||
|
dimElemCounts := make([]int, len(dimensions))
|
||||||
|
dimElemCounts[len(dimensions)-1] = int(dimensions[len(dimensions)-1].Length)
|
||||||
|
for i := len(dimensions) - 2; i > -1; i-- {
|
||||||
|
dimElemCounts[i] = int(dimensions[i].Length) * dimElemCounts[i+1]
|
||||||
|
}
|
||||||
|
|
||||||
|
inElemBuf := make([]byte, 0, 32)
|
||||||
|
elementCount := cardinality(dimensions)
|
||||||
|
for i := 0; i < elementCount; i++ {
|
||||||
|
if i > 0 {
|
||||||
|
buf = append(buf, ',')
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dec := range dimElemCounts {
|
||||||
|
if i%dec == 0 {
|
||||||
|
buf = append(buf, '{')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
elemBuf, err := c.ElementCodec.Encode(ci, c.ElementOID, TextFormatCode, array.Index(i), inElemBuf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if elemBuf == nil {
|
||||||
|
buf = append(buf, `NULL`...)
|
||||||
|
} else {
|
||||||
|
buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dec := range dimElemCounts {
|
||||||
|
if (i+1)%dec == 0 {
|
||||||
|
buf = append(buf, '}')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ArrayCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan {
|
||||||
|
_, err := makeArraySetter(target)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return (*scanPlanArrayCodec)(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ArrayCodec) decodeBinary(ci *ConnInfo, arrayOID uint32, src []byte, array ArraySetter) error {
|
||||||
|
var arrayHeader ArrayHeader
|
||||||
|
rp, err := arrayHeader.DecodeBinary(ci, src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO - ArrayHeader.DecodeBinary should do this. But doing this there breaks old array code. Leave until old code
|
||||||
|
// can be removed.
|
||||||
|
if arrayHeader.Dimensions == nil {
|
||||||
|
arrayHeader.Dimensions = []ArrayDimension{}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = array.SetDimensions(arrayHeader.Dimensions)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
elementCount := cardinality(arrayHeader.Dimensions)
|
||||||
|
if elementCount == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
elementScanPlan := c.ElementCodec.PlanScan(ci, c.ElementOID, BinaryFormatCode, array.ScanIndex(0), false)
|
||||||
|
if elementScanPlan == nil {
|
||||||
|
elementScanPlan = ci.PlanScan(c.ElementOID, BinaryFormatCode, array.ScanIndex(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < elementCount; i++ {
|
||||||
|
elem := array.ScanIndex(i)
|
||||||
|
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||||
|
rp += 4
|
||||||
|
var elemSrc []byte
|
||||||
|
if elemLen >= 0 {
|
||||||
|
elemSrc = src[rp : rp+elemLen]
|
||||||
|
rp += elemLen
|
||||||
|
}
|
||||||
|
err = elementScanPlan.Scan(ci, c.ElementOID, BinaryFormatCode, elemSrc, elem)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ArrayCodec) decodeText(ci *ConnInfo, arrayOID uint32, src []byte, array ArraySetter) error {
|
||||||
|
uta, err := ParseUntypedTextArray(string(src))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO - ParseUntypedTextArray should do this. But doing this there breaks old array code. Leave until old code
|
||||||
|
// can be removed.
|
||||||
|
if uta.Dimensions == nil {
|
||||||
|
uta.Dimensions = []ArrayDimension{}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = array.SetDimensions(uta.Dimensions)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(uta.Elements) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
elementScanPlan := c.ElementCodec.PlanScan(ci, c.ElementOID, TextFormatCode, array.ScanIndex(0), false)
|
||||||
|
if elementScanPlan == nil {
|
||||||
|
elementScanPlan = ci.PlanScan(c.ElementOID, TextFormatCode, array.ScanIndex(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, s := range uta.Elements {
|
||||||
|
elem := array.ScanIndex(i)
|
||||||
|
var elemSrc []byte
|
||||||
|
if s != "NULL" {
|
||||||
|
elemSrc = []byte(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = elementScanPlan.Scan(ci, c.ElementOID, TextFormatCode, elemSrc, elem)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type scanPlanArrayCodec ArrayCodec
|
||||||
|
|
||||||
|
func (spac *scanPlanArrayCodec) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
|
||||||
|
c := (*ArrayCodec)(spac)
|
||||||
|
|
||||||
|
array, err := makeArraySetter(dst)
|
||||||
|
if err != nil {
|
||||||
|
newPlan := ci.PlanScan(oid, formatCode, dst)
|
||||||
|
return newPlan.Scan(ci, oid, formatCode, src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
if src == nil {
|
||||||
|
return array.SetDimensions(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch formatCode {
|
||||||
|
case BinaryFormatCode:
|
||||||
|
return c.decodeBinary(ci, oid, src, array)
|
||||||
|
case TextFormatCode:
|
||||||
|
return c.decodeText(ci, oid, src, array)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown format code %d", formatCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ArrayCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||||
|
if src == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// var n int64
|
||||||
|
// err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n)
|
||||||
|
// return n, err
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ArrayCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) {
|
||||||
|
if src == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// var n int16
|
||||||
|
// err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n)
|
||||||
|
// return n, err
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("not implemented")
|
||||||
|
}
|
||||||
@@ -0,0 +1,106 @@
|
|||||||
|
package pgtype_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgtype/testutil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestArrayCodec(t *testing.T) {
|
||||||
|
conn := testutil.MustConnectPgx(t)
|
||||||
|
defer testutil.MustCloseContext(t, conn)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
expected []int16
|
||||||
|
}{
|
||||||
|
{[]int16(nil)},
|
||||||
|
{[]int16{}},
|
||||||
|
{[]int16{1, 2, 3}},
|
||||||
|
}
|
||||||
|
for i, tt := range tests {
|
||||||
|
var actual []int16
|
||||||
|
err := conn.QueryRow(
|
||||||
|
context.Background(),
|
||||||
|
"select $1::smallint[]",
|
||||||
|
tt.expected,
|
||||||
|
).Scan(&actual)
|
||||||
|
assert.NoErrorf(t, err, "%d", i)
|
||||||
|
assert.Equalf(t, tt.expected, actual, "%d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// func TestArrayCodecValue(t *testing.T) {
|
||||||
|
// ArrayCodec := pgtype.NewArrayCodec("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} })
|
||||||
|
|
||||||
|
// err := ArrayCodec.Set(nil)
|
||||||
|
// require.NoError(t, err)
|
||||||
|
|
||||||
|
// gotValue := ArrayCodec.Get()
|
||||||
|
// require.Nil(t, gotValue)
|
||||||
|
|
||||||
|
// slice := []string{"foo", "bar"}
|
||||||
|
// err = ArrayCodec.AssignTo(&slice)
|
||||||
|
// require.NoError(t, err)
|
||||||
|
// require.Nil(t, slice)
|
||||||
|
|
||||||
|
// err = ArrayCodec.Set([]string{})
|
||||||
|
// require.NoError(t, err)
|
||||||
|
|
||||||
|
// gotValue = ArrayCodec.Get()
|
||||||
|
// require.Len(t, gotValue, 0)
|
||||||
|
|
||||||
|
// err = ArrayCodec.AssignTo(&slice)
|
||||||
|
// require.NoError(t, err)
|
||||||
|
// require.EqualValues(t, []string{}, slice)
|
||||||
|
|
||||||
|
// err = ArrayCodec.Set([]string{"baz", "quz"})
|
||||||
|
// require.NoError(t, err)
|
||||||
|
|
||||||
|
// gotValue = ArrayCodec.Get()
|
||||||
|
// require.Len(t, gotValue, 2)
|
||||||
|
|
||||||
|
// err = ArrayCodec.AssignTo(&slice)
|
||||||
|
// require.NoError(t, err)
|
||||||
|
// require.EqualValues(t, []string{"baz", "quz"}, slice)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func TestArrayCodecTranscode(t *testing.T) {
|
||||||
|
// conn := testutil.MustConnectPgx(t)
|
||||||
|
// defer testutil.MustCloseContext(t, conn)
|
||||||
|
|
||||||
|
// conn.ConnInfo().RegisterDataType(pgtype.DataType{
|
||||||
|
// Value: pgtype.NewArrayCodec("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }),
|
||||||
|
// Name: "_text",
|
||||||
|
// OID: pgtype.TextArrayOID,
|
||||||
|
// })
|
||||||
|
|
||||||
|
// var dstStrings []string
|
||||||
|
// err := conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings)
|
||||||
|
// require.NoError(t, err)
|
||||||
|
|
||||||
|
// require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func TestArrayCodecEmptyArrayDoesNotBreakArrayCodec(t *testing.T) {
|
||||||
|
// conn := testutil.MustConnectPgx(t)
|
||||||
|
// defer testutil.MustCloseContext(t, conn)
|
||||||
|
|
||||||
|
// conn.ConnInfo().RegisterDataType(pgtype.DataType{
|
||||||
|
// Value: pgtype.NewArrayCodec("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }),
|
||||||
|
// Name: "_text",
|
||||||
|
// OID: pgtype.TextArrayOID,
|
||||||
|
// })
|
||||||
|
|
||||||
|
// var dstStrings []string
|
||||||
|
// err := conn.QueryRow(context.Background(), "select '{}'::text[]").Scan(&dstStrings)
|
||||||
|
// require.NoError(t, err)
|
||||||
|
|
||||||
|
// require.EqualValues(t, []string{}, dstStrings)
|
||||||
|
|
||||||
|
// err = conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings)
|
||||||
|
// require.NoError(t, err)
|
||||||
|
|
||||||
|
// require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings)
|
||||||
|
// }
|
||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -452,6 +453,141 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) {
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func convertToInt64ForEncode(v interface{}) (n int64, valid bool, err error) {
|
||||||
|
if v == nil {
|
||||||
|
return 0, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := v.(type) {
|
||||||
|
case int8:
|
||||||
|
return int64(v), true, nil
|
||||||
|
case uint8:
|
||||||
|
return int64(v), true, nil
|
||||||
|
case int16:
|
||||||
|
return int64(v), true, nil
|
||||||
|
case uint16:
|
||||||
|
return int64(v), true, nil
|
||||||
|
case int32:
|
||||||
|
return int64(v), true, nil
|
||||||
|
case uint32:
|
||||||
|
return int64(v), true, nil
|
||||||
|
case int64:
|
||||||
|
return int64(v), true, nil
|
||||||
|
case uint64:
|
||||||
|
if v > math.MaxInt64 {
|
||||||
|
return 0, false, fmt.Errorf("%d is greater than maximum value for int64", v)
|
||||||
|
}
|
||||||
|
return int64(v), true, nil
|
||||||
|
case int:
|
||||||
|
return int64(v), true, nil
|
||||||
|
case uint:
|
||||||
|
if v > math.MaxInt64 {
|
||||||
|
return 0, false, fmt.Errorf("%d is greater than maximum value for int64", v)
|
||||||
|
}
|
||||||
|
return int64(v), true, nil
|
||||||
|
case string:
|
||||||
|
num, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, false, err
|
||||||
|
}
|
||||||
|
return num, true, nil
|
||||||
|
case float32:
|
||||||
|
if v > math.MaxInt64 {
|
||||||
|
return 0, false, fmt.Errorf("%f is greater than maximum value for int64", v)
|
||||||
|
}
|
||||||
|
return int64(v), true, nil
|
||||||
|
case float64:
|
||||||
|
if v > math.MaxInt64 {
|
||||||
|
return 0, false, fmt.Errorf("%f is greater than maximum value for int64", v)
|
||||||
|
}
|
||||||
|
return int64(v), true, nil
|
||||||
|
case *int8:
|
||||||
|
if v == nil {
|
||||||
|
return 0, false, nil
|
||||||
|
} else {
|
||||||
|
return convertToInt64ForEncode(*v)
|
||||||
|
}
|
||||||
|
case *uint8:
|
||||||
|
if v == nil {
|
||||||
|
return 0, false, nil
|
||||||
|
} else {
|
||||||
|
return convertToInt64ForEncode(*v)
|
||||||
|
}
|
||||||
|
case *int16:
|
||||||
|
if v == nil {
|
||||||
|
return 0, false, nil
|
||||||
|
} else {
|
||||||
|
return convertToInt64ForEncode(*v)
|
||||||
|
}
|
||||||
|
case *uint16:
|
||||||
|
if v == nil {
|
||||||
|
return 0, false, nil
|
||||||
|
} else {
|
||||||
|
return convertToInt64ForEncode(*v)
|
||||||
|
}
|
||||||
|
case *int32:
|
||||||
|
if v == nil {
|
||||||
|
return 0, false, nil
|
||||||
|
} else {
|
||||||
|
return convertToInt64ForEncode(*v)
|
||||||
|
}
|
||||||
|
case *uint32:
|
||||||
|
if v == nil {
|
||||||
|
return 0, false, nil
|
||||||
|
} else {
|
||||||
|
return convertToInt64ForEncode(*v)
|
||||||
|
}
|
||||||
|
case *int64:
|
||||||
|
if v == nil {
|
||||||
|
return 0, false, nil
|
||||||
|
} else {
|
||||||
|
return convertToInt64ForEncode(*v)
|
||||||
|
}
|
||||||
|
case *uint64:
|
||||||
|
if v == nil {
|
||||||
|
return 0, false, nil
|
||||||
|
} else {
|
||||||
|
return convertToInt64ForEncode(*v)
|
||||||
|
}
|
||||||
|
case *int:
|
||||||
|
if v == nil {
|
||||||
|
return 0, false, nil
|
||||||
|
} else {
|
||||||
|
return convertToInt64ForEncode(*v)
|
||||||
|
}
|
||||||
|
case *uint:
|
||||||
|
if v == nil {
|
||||||
|
return 0, false, nil
|
||||||
|
} else {
|
||||||
|
return convertToInt64ForEncode(*v)
|
||||||
|
}
|
||||||
|
case *string:
|
||||||
|
if v == nil {
|
||||||
|
return 0, false, nil
|
||||||
|
} else {
|
||||||
|
return convertToInt64ForEncode(*v)
|
||||||
|
}
|
||||||
|
case *float32:
|
||||||
|
if v == nil {
|
||||||
|
return 0, false, nil
|
||||||
|
} else {
|
||||||
|
return convertToInt64ForEncode(*v)
|
||||||
|
}
|
||||||
|
case *float64:
|
||||||
|
if v == nil {
|
||||||
|
return 0, false, nil
|
||||||
|
} else {
|
||||||
|
return convertToInt64ForEncode(*v)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
if originalvalue, ok := underlyingNumberType(v); ok {
|
||||||
|
return convertToInt64ForEncode(originalvalue)
|
||||||
|
}
|
||||||
|
return 0, false, fmt.Errorf("cannot convert %v to int64", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
kindTypes = map[reflect.Kind]reflect.Type{
|
kindTypes = map[reflect.Kind]reflect.Type{
|
||||||
reflect.Bool: reflect.TypeOf(false),
|
reflect.Bool: reflect.TypeOf(false),
|
||||||
|
|||||||
@@ -0,0 +1,146 @@
|
|||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Int2Codec struct{}
|
||||||
|
|
||||||
|
func (Int2Codec) FormatSupported(format int16) bool {
|
||||||
|
return format == TextFormatCode || format == BinaryFormatCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Int2Codec) PreferredFormat() int16 {
|
||||||
|
return BinaryFormatCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Int2Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) {
|
||||||
|
n, valid, err := convertToInt64ForEncode(value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot convert %v to int2: %v", value, err)
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if n > math.MaxInt16 {
|
||||||
|
return nil, fmt.Errorf("%d is greater than maximum value for int2", n)
|
||||||
|
}
|
||||||
|
if n < math.MinInt16 {
|
||||||
|
return nil, fmt.Errorf("%d is less than minimum value for int2", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch format {
|
||||||
|
case BinaryFormatCode:
|
||||||
|
return pgio.AppendInt16(buf, int16(n)), nil
|
||||||
|
case TextFormatCode:
|
||||||
|
return append(buf, strconv.FormatInt(n, 10)...), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown format code: %v", format)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan {
|
||||||
|
switch format {
|
||||||
|
case BinaryFormatCode:
|
||||||
|
case TextFormatCode:
|
||||||
|
switch target.(type) {
|
||||||
|
case *int16:
|
||||||
|
return scanPlanTextToAnyInt16{}
|
||||||
|
case *int32:
|
||||||
|
return scanPlanTextToAnyInt32{}
|
||||||
|
case *int64:
|
||||||
|
return scanPlanTextToAnyInt64{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Int2Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||||
|
if src == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) {
|
||||||
|
if src == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int16
|
||||||
|
err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type scanPlanTextToAnyInt16 struct{}
|
||||||
|
|
||||||
|
func (scanPlanTextToAnyInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
|
||||||
|
if src == nil {
|
||||||
|
return fmt.Errorf("cannot scan null into %T", dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
p, ok := (dst).(*int16)
|
||||||
|
if !ok {
|
||||||
|
return ErrScanTargetTypeChanged
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := strconv.ParseInt(string(src), 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*p = int16(n)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type scanPlanTextToAnyInt32 struct{}
|
||||||
|
|
||||||
|
func (scanPlanTextToAnyInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
|
||||||
|
if src == nil {
|
||||||
|
return fmt.Errorf("cannot scan null into %T", dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
p, ok := (dst).(*int32)
|
||||||
|
if !ok {
|
||||||
|
return ErrScanTargetTypeChanged
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := strconv.ParseInt(string(src), 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*p = int32(n)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type scanPlanTextToAnyInt64 struct{}
|
||||||
|
|
||||||
|
func (scanPlanTextToAnyInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
|
||||||
|
if src == nil {
|
||||||
|
return fmt.Errorf("cannot scan null into %T", dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
p, ok := (dst).(*int64)
|
||||||
|
if !ok {
|
||||||
|
return ErrScanTargetTypeChanged
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := strconv.ParseInt(string(src), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*p = int64(n)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
+45
-2
@@ -2,7 +2,9 @@ package pgtype
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
@@ -173,6 +175,34 @@ type ResultDecoder interface {
|
|||||||
DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error
|
DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Encoder interface {
|
||||||
|
// Encode appends the encoded bytes of value to buf. If value is the SQL NULL then append nothing and return
|
||||||
|
// (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data
|
||||||
|
// written.
|
||||||
|
Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Codec interface {
|
||||||
|
// FormatSupported returns true if the format is supported.
|
||||||
|
FormatSupported(int16) bool
|
||||||
|
|
||||||
|
// PreferredFormat returns the preferred format.
|
||||||
|
PreferredFormat() int16
|
||||||
|
|
||||||
|
Encoder
|
||||||
|
|
||||||
|
// PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If
|
||||||
|
// actualTarget is true then the returned ScanPlan may be optimized to directly scan into target. If no plan can be
|
||||||
|
// found then nil is returned.
|
||||||
|
PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan
|
||||||
|
|
||||||
|
// DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface.
|
||||||
|
DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error)
|
||||||
|
|
||||||
|
// DecodeValue returns src decoded into its default format.
|
||||||
|
DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error)
|
||||||
|
}
|
||||||
|
|
||||||
// ResultFormatPreferrer allows a type to specify its preferred result format instead of it being inferred from
|
// ResultFormatPreferrer allows a type to specify its preferred result format instead of it being inferred from
|
||||||
// whether it is also a BinaryDecoder.
|
// whether it is also a BinaryDecoder.
|
||||||
type ResultFormatPreferrer interface {
|
type ResultFormatPreferrer interface {
|
||||||
@@ -229,6 +259,8 @@ type DataType struct {
|
|||||||
textDecoder TextDecoder
|
textDecoder TextDecoder
|
||||||
binaryDecoder BinaryDecoder
|
binaryDecoder BinaryDecoder
|
||||||
|
|
||||||
|
Codec Codec
|
||||||
|
|
||||||
Name string
|
Name string
|
||||||
OID uint32
|
OID uint32
|
||||||
}
|
}
|
||||||
@@ -268,7 +300,7 @@ func NewConnInfo() *ConnInfo {
|
|||||||
ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID})
|
ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID})
|
||||||
ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID})
|
ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID})
|
||||||
ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID})
|
ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID})
|
||||||
ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID})
|
ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}})
|
||||||
ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID})
|
ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID})
|
||||||
ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID})
|
ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID})
|
||||||
ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID})
|
ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID})
|
||||||
@@ -292,7 +324,7 @@ func NewConnInfo() *ConnInfo {
|
|||||||
ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID})
|
ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID})
|
||||||
ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID})
|
ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID})
|
||||||
ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID})
|
ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID})
|
||||||
ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID})
|
ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID, Codec: Int2Codec{}})
|
||||||
ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID})
|
ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID})
|
||||||
ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID})
|
ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID})
|
||||||
ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID})
|
ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID})
|
||||||
@@ -752,6 +784,15 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt
|
|||||||
|
|
||||||
// PlanScan prepares a plan to scan a value into dst.
|
// PlanScan prepares a plan to scan a value into dst.
|
||||||
func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan {
|
func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan {
|
||||||
|
if oid != 0 {
|
||||||
|
if dt, ok := ci.DataTypeForOID(oid); ok && dt.Codec != nil {
|
||||||
|
plan := dt.Codec.PlanScan(ci, oid, formatCode, dst, false)
|
||||||
|
if plan != nil {
|
||||||
|
return plan
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch formatCode {
|
switch formatCode {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch dst.(type) {
|
switch dst.(type) {
|
||||||
@@ -866,6 +907,8 @@ func NewValue(v Value) Value {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrScanTargetTypeChanged = errors.New("scan target type changed")
|
||||||
|
|
||||||
var nameValues map[string]Value
|
var nameValues map[string]Value
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
Reference in New Issue
Block a user