From 9fc8f9b3a8b0a82bc271c7450ec2ecbe6e5f21a4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 30 Dec 2021 18:12:47 -0600 Subject: [PATCH] Initial passing tests for main pgx package --- extended_query_builder.go | 37 +- pgtype/array_codec.go | 55 +-- pgtype/int2.go | 265 +---------- pgtype/int2_array.go | 896 -------------------------------------- pgtype/int2_codec.go | 223 +++++++++- pgtype/int2_test.go | 213 ++++----- pgtype/pgtype.go | 201 +++++---- pgtype/typed_array_gen.sh | 1 - pgtype/zzz.int2.go | 35 -- query_test.go | 73 ++-- rows.go | 49 ++- values.go | 93 ++-- 12 files changed, 574 insertions(+), 1567 deletions(-) delete mode 100644 pgtype/int2_array.go delete mode 100644 pgtype/zzz.int2.go diff --git a/extended_query_builder.go b/extended_query_builder.go index 1420a808..480e35d3 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -113,23 +113,34 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o } if dt, ok := ci.DataTypeForOID(oid); ok { - value := dt.Value - err := value.Set(arg) - if err != nil { - { - if arg, ok := arg.(driver.Valuer); ok { - v, err := callValuerValue(arg) - if err != nil { - return nil, err + if dt.Value != nil { + value := dt.Value + err := value.Set(arg) + if err != nil { + { + if arg, ok := arg.(driver.Valuer); ok { + v, err := callValuerValue(arg) + if err != nil { + return nil, err + } + return eqb.encodeExtendedParamValue(ci, oid, formatCode, v) } - return eqb.encodeExtendedParamValue(ci, oid, formatCode, v) } + + return nil, err } - - return nil, err + return eqb.encodeExtendedParamValue(ci, oid, formatCode, value) + } else if dt.Codec != nil { + buf, err := dt.Codec.Encode(ci, oid, formatCode, arg, eqb.paramValueBytes) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil } - - return eqb.encodeExtendedParamValue(ci, oid, formatCode, value) } // There is no data type registered for the destination OID, but maybe there is data type registered for the arg diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index b72290a0..16ce7382 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -28,57 +28,6 @@ type ArraySetter interface { ScanIndex(i int) interface{} } -type int16Array []int16 - -func (a int16Array) Dimensions() []ArrayDimension { - if a == nil { - return nil - } - - return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} -} - -func (a int16Array) Index(i int) interface{} { - return a[i] -} - -func (a *int16Array) SetDimensions(dimensions []ArrayDimension) error { - if dimensions == nil { - a = nil - return nil - } - - elementCount := cardinality(dimensions) - *a = make(int16Array, elementCount) - return nil -} - -func (a int16Array) ScanIndex(i int) interface{} { - return &a[i] -} - -func makeArrayGetter(a interface{}) (ArrayGetter, error) { - switch a := a.(type) { - case ArrayGetter: - return a, nil - case []int16: - return (*int16Array)(&a), nil - } - - return nil, fmt.Errorf("cannot convert %T to ArrayGetter", a) -} - -func makeArraySetter(a interface{}) (ArraySetter, error) { - switch a := a.(type) { - case ArraySetter: - return a, nil - case *[]int16: - return (*int16Array)(a), nil - } - - return nil, fmt.Errorf("cannot convert %T to ArraySetter", a) -} - // ArrayCodec is a codec for any array type. type ArrayCodec struct { ElementCodec Codec @@ -155,7 +104,8 @@ func (c *ArrayCodec) encodeText(ci *ConnInfo, oid uint32, array ArrayGetter, buf return nil, nil } - if len(dimensions) == 0 { + elementCount := cardinality(dimensions) + if elementCount == 0 { return append(buf, '{', '}'), nil } @@ -173,7 +123,6 @@ func (c *ArrayCodec) encodeText(ci *ConnInfo, oid uint32, array ArrayGetter, buf } inElemBuf := make([]byte, 0, 32) - elementCount := cardinality(dimensions) for i := 0; i < elementCount; i++ { if i > 0 { buf = append(buf, ',') diff --git a/pgtype/int2.go b/pgtype/int2.go index bbfee1cf..b7b7243f 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -2,12 +2,9 @@ package pgtype import ( "database/sql/driver" - "encoding/binary" "fmt" "math" "strconv" - - "github.com/jackc/pgio" ) type Int2 struct { @@ -15,231 +12,6 @@ type Int2 struct { Valid bool } -func (dst *Int2) Set(src interface{}) error { - if src == nil { - *dst = Int2{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case int8: - *dst = Int2{Int: int16(value), Valid: true} - case uint8: - *dst = Int2{Int: int16(value), Valid: true} - case int16: - *dst = Int2{Int: int16(value), Valid: true} - case uint16: - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case int32: - if value < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case uint32: - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case int64: - if value < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case uint64: - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case int: - if value < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case uint: - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case string: - num, err := strconv.ParseInt(value, 10, 16) - if err != nil { - return err - } - *dst = Int2{Int: int16(num), Valid: true} - case float32: - if value > math.MaxInt16 { - return fmt.Errorf("%f is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case float64: - if value > math.MaxInt16 { - return fmt.Errorf("%f is greater than maximum value for Int2", value) - } - *dst = Int2{Int: int16(value), Valid: true} - case *int8: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *uint8: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *int16: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *uint16: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *int32: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *uint32: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *int64: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *uint64: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *int: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *uint: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *string: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *float32: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - case *float64: - if value == nil { - *dst = Int2{} - } else { - return dst.Set(*value) - } - default: - if originalSrc, ok := underlyingNumberType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Int2", value) - } - - return nil -} - -func (dst Int2) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.Int -} - -func (src *Int2) AssignTo(dst interface{}) error { - return int64AssignTo(int64(src.Int), src.Valid, dst) -} - -func (dst *Int2) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int2{} - return nil - } - - n, err := strconv.ParseInt(string(src), 10, 16) - if err != nil { - return err - } - - *dst = Int2{Int: int16(n), Valid: true} - return nil -} - -func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int2{} - return nil - } - - if len(src) != 2 { - return fmt.Errorf("invalid length for int2: %v", len(src)) - } - - n := int16(binary.BigEndian.Uint16(src)) - *dst = Int2{Int: n, Valid: true} - return nil -} - -func (src Int2) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil -} - -func (src Int2) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return pgio.AppendInt16(buf, src.Int), nil -} - // Scan implements the database/sql Scanner interface. func (dst *Int2) Scan(src interface{}) error { if src == nil { @@ -247,25 +19,36 @@ func (dst *Int2) Scan(src interface{}) error { return nil } + var n int64 + switch src := src.(type) { case int64: - if src < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", src) - } - if src > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", src) - } - *dst = Int2{Int: int16(src), Valid: true} - return nil + n = src case string: - return dst.DecodeText(nil, []byte(src)) + var err error + n, err = strconv.ParseInt(src, 10, 16) + if err != nil { + return err + } case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + var err error + n, err = strconv.ParseInt(string(src), 10, 16) + if err != nil { + return err + } + default: + return fmt.Errorf("cannot scan %T", src) } - return fmt.Errorf("cannot scan %T", src) + if n < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + if n > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n) + } + *dst = Int2{Int: int16(n), Valid: true} + + return nil } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go deleted file mode 100644 index d96240dc..00000000 --- a/pgtype/int2_array.go +++ /dev/null @@ -1,896 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type Int2Array struct { - Elements []Int2 - Dimensions []ArrayDimension - Valid bool -} - -func (dst *Int2Array) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = Int2Array{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []int16: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int16: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint16: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint16: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []int32: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int32: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint32: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint32: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []int64: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int64: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint64: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint64: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []int: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*int: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []uint: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*uint: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Int2: - if value == nil { - *dst = Int2Array{} - } else if len(value) == 0 { - *dst = Int2Array{Valid: true} - } else { - *dst = Int2Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = Int2Array{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for Int2Array", src) - } - if elementsLength == 0 { - *dst = Int2Array{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Int2Array", src) - } - - *dst = Int2Array{ - Elements: make([]Int2, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Int2, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to Int2Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *Int2Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to Int2Array") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in Int2Array", err) - } - index++ - - return index, nil -} - -func (dst Int2Array) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *Int2Array) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from Int2Array") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from Int2Array") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int2Array{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Int2 - - if len(uta.Elements) > 0 { - elements = make([]Int2, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Int2 - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Int2Array{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = Int2Array{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Int2, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("int2"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "int2") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Int2Array) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Int2Array) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/int2_codec.go b/pgtype/int2_codec.go index 7ea50870..c50b56d7 100644 --- a/pgtype/int2_codec.go +++ b/pgtype/int2_codec.go @@ -2,6 +2,7 @@ package pgtype import ( "database/sql/driver" + "encoding/binary" "fmt" "math" "strconv" @@ -46,16 +47,31 @@ func (Int2Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{ } func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { case BinaryFormatCode: case TextFormatCode: switch target.(type) { + case *int8: + return scanPlanTextAnyToInt8{} case *int16: - return scanPlanTextToAnyInt16{} + return scanPlanTextAnyToInt16{} case *int32: - return scanPlanTextToAnyInt32{} + return scanPlanTextAnyToInt32{} case *int64: - return scanPlanTextToAnyInt64{} + return scanPlanTextAnyToInt64{} + case *int: + return scanPlanTextAnyToInt{} + case *uint8: + return scanPlanTextAnyToUint8{} + case *uint16: + return scanPlanTextAnyToUint16{} + case *uint32: + return scanPlanTextAnyToUint32{} + case *uint64: + return scanPlanTextAnyToUint64{} + case *uint: + return scanPlanTextAnyToUint{} } } @@ -68,8 +84,15 @@ func (c Int2Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16 } var n int64 - err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) - return n, err + scanPlan := c.PlanScan(ci, oid, format, &n, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil } func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { @@ -78,13 +101,61 @@ func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt } var n int16 - err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n) - return n, err + scanPlan := c.PlanScan(ci, oid, format, &n, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil } -type scanPlanTextToAnyInt16 struct{} +type scanPlanBinaryInt2ToInt16 struct{} -func (scanPlanTextToAnyInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanBinaryInt2ToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + p, ok := (dst).(*int16) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = int16(binary.BigEndian.Uint16(src)) + return nil +} + +type scanPlanTextAnyToInt8 struct{} + +func (scanPlanTextAnyToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int8) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 8) + if err != nil { + return err + } + + *p = int8(n) + return nil +} + +type scanPlanTextAnyToInt16 struct{} + +func (scanPlanTextAnyToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -103,9 +174,9 @@ func (scanPlanTextToAnyInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, s return nil } -type scanPlanTextToAnyInt32 struct{} +type scanPlanTextAnyToInt32 struct{} -func (scanPlanTextToAnyInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -124,9 +195,9 @@ func (scanPlanTextToAnyInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, s return nil } -type scanPlanTextToAnyInt64 struct{} +type scanPlanTextAnyToInt64 struct{} -func (scanPlanTextToAnyInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (scanPlanTextAnyToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if src == nil { return fmt.Errorf("cannot scan null into %T", dst) } @@ -144,3 +215,129 @@ func (scanPlanTextToAnyInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, s *p = int64(n) return nil } + +type scanPlanTextAnyToInt struct{} + +func (scanPlanTextAnyToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*int) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseInt(string(src), 10, 0) + if err != nil { + return err + } + + *p = int(n) + return nil +} + +type scanPlanTextAnyToUint8 struct{} + +func (scanPlanTextAnyToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint8) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 8) + if err != nil { + return err + } + + *p = uint8(n) + return nil +} + +type scanPlanTextAnyToUint16 struct{} + +func (scanPlanTextAnyToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint16) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 16) + if err != nil { + return err + } + + *p = uint16(n) + return nil +} + +type scanPlanTextAnyToUint32 struct{} + +func (scanPlanTextAnyToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint32) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + *p = uint32(n) + return nil +} + +type scanPlanTextAnyToUint64 struct{} + +func (scanPlanTextAnyToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint64) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 64) + if err != nil { + return err + } + + *p = uint64(n) + return nil +} + +type scanPlanTextAnyToUint struct{} + +func (scanPlanTextAnyToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p, ok := (dst).(*uint) + if !ok { + return ErrScanTargetTypeChanged + } + + n, err := strconv.ParseUint(string(src), 10, 0) + if err != nil { + return err + } + + *p = uint(n) + return nil +} diff --git a/pgtype/int2_test.go b/pgtype/int2_test.go index 58dcd141..f5bdac89 100644 --- a/pgtype/int2_test.go +++ b/pgtype/int2_test.go @@ -1,144 +1,95 @@ package pgtype_test import ( + "context" + "fmt" "math" "reflect" "testing" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestInt2Transcode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ - &pgtype.Int2{Int: math.MinInt16, Valid: true}, - &pgtype.Int2{Int: -1, Valid: true}, - &pgtype.Int2{Int: 0, Valid: true}, - &pgtype.Int2{Int: 1, Valid: true}, - &pgtype.Int2{Int: math.MaxInt16, Valid: true}, - &pgtype.Int2{Int: 0}, +type PgxTranscodeTestCase struct { + src interface{} + dst interface{} + test func(interface{}) bool +} + +func isExpectedEq(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + return a == v + } +} + +func testPgxCodec(t testing.TB, pgTypeName string, tests []PgxTranscodeTestCase) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for i, tt := range tests { + for _, format := range formats { + err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{format.code}, tt.src).Scan(tt.dst) + if err != nil { + t.Errorf("%s %d: %v", format.name, i, err) + } + + dst := reflect.ValueOf(tt.dst) + if dst.Kind() == reflect.Ptr { + dst = dst.Elem() + } + + if !tt.test(dst.Interface()) { + t.Errorf("%s %d: unexpected result for %v: %v", format.name, i, tt.src, dst.Interface()) + } + } + } +} + +func TestInt2Codec(t *testing.T) { + testPgxCodec(t, "int2", []PgxTranscodeTestCase{ + {int8(1), new(int16), isExpectedEq(int16(1))}, + {int16(1), new(int16), isExpectedEq(int16(1))}, + {int32(1), new(int16), isExpectedEq(int16(1))}, + {int64(1), new(int16), isExpectedEq(int16(1))}, + {uint8(1), new(int16), isExpectedEq(int16(1))}, + {uint16(1), new(int16), isExpectedEq(int16(1))}, + {uint32(1), new(int16), isExpectedEq(int16(1))}, + {uint64(1), new(int16), isExpectedEq(int16(1))}, + {int(1), new(int16), isExpectedEq(int16(1))}, + {uint(1), new(int16), isExpectedEq(int16(1))}, + {pgtype.Int2{Int: 1, Valid: true}, new(int16), isExpectedEq(int16(1))}, + {1, new(int8), isExpectedEq(int8(1))}, + {1, new(int16), isExpectedEq(int16(1))}, + {1, new(int32), isExpectedEq(int32(1))}, + {1, new(int64), isExpectedEq(int64(1))}, + {1, new(uint8), isExpectedEq(uint8(1))}, + {1, new(uint16), isExpectedEq(uint16(1))}, + {1, new(uint32), isExpectedEq(uint32(1))}, + {1, new(uint64), isExpectedEq(uint64(1))}, + {1, new(int), isExpectedEq(int(1))}, + {1, new(uint), isExpectedEq(uint(1))}, + {math.MinInt16, new(int16), isExpectedEq(int16(math.MinInt16))}, + {-1, new(int16), isExpectedEq(int16(-1))}, + {0, new(int16), isExpectedEq(int16(0))}, + {1, new(int16), isExpectedEq(int16(1))}, + {math.MaxInt16, new(int16), isExpectedEq(int16(math.MaxInt16))}, + {1, new(pgtype.Int2), isExpectedEq(pgtype.Int2{Int: 1, Valid: true})}, + {pgtype.Int2{}, new(pgtype.Int2), isExpectedEq(pgtype.Int2{})}, + {nil, new(*int16), isExpectedEq((*int16)(nil))}, }) } - -func TestInt2Set(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Int2 - }{ - {source: int8(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: int16(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: int32(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: int64(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: int8(-1), result: pgtype.Int2{Int: -1, Valid: true}}, - {source: int16(-1), result: pgtype.Int2{Int: -1, Valid: true}}, - {source: int32(-1), result: pgtype.Int2{Int: -1, Valid: true}}, - {source: int64(-1), result: pgtype.Int2{Int: -1, Valid: true}}, - {source: uint8(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: uint16(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: uint32(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: uint64(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: float32(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: float64(1), result: pgtype.Int2{Int: 1, Valid: true}}, - {source: "1", result: pgtype.Int2{Int: 1, Valid: true}}, - {source: _int8(1), result: pgtype.Int2{Int: 1, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.Int2 - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInt2AssignTo(t *testing.T) { - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - - simpleTests := []struct { - src pgtype.Int2 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i8, expected: int8(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i16, expected: int16(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i32, expected: int32(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i64, expected: int64(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i, expected: int(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui, expected: uint(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, - {src: pgtype.Int2{Int: 0}, dst: &pi8, expected: ((*int8)(nil))}, - {src: pgtype.Int2{Int: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Int2 - dst interface{} - expected interface{} - }{ - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &pi8, expected: int8(42)}, - {src: pgtype.Int2{Int: 42, Valid: true}, dst: &_pi8, expected: _int8(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Int2 - dst interface{} - }{ - {src: pgtype.Int2{Int: 150, Valid: true}, dst: &i8}, - {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui8}, - {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui16}, - {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui32}, - {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui64}, - {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui}, - {src: pgtype.Int2{Int: 0}, dst: &i16}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index b0b07663..4983c3a2 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -300,7 +300,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID}) ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID}) ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID}) - ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}}) + ci.RegisterDataType(DataType{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}}) ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID}) ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID}) ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) @@ -324,7 +324,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID}) ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID}) - ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) + ci.RegisterDataType(DataType{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID}) ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID}) @@ -398,20 +398,10 @@ func NewConnInfo() *ConnInfo { return ci } -func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) { - for name, oid := range nameOIDs { - var value Value - if t, ok := nameValues[name]; ok { - value = reflect.New(reflect.ValueOf(t).Elem().Type()).Interface().(Value) - } else { - value = &GenericText{} - } - ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid}) - } -} - func (ci *ConnInfo) RegisterDataType(t DataType) { - t.Value = NewValue(t.Value) + if t.Value != nil { + t.Value = NewValue(t.Value) + } ci.oidToDataType[t.OID] = &t ci.nameToDataType[t.Name] = &t @@ -463,8 +453,10 @@ func (ci *ConnInfo) buildReflectTypeToDataType() { ci.reflectTypeToDataType = make(map[reflect.Type]*DataType) for _, dt := range ci.oidToDataType { - if _, is := dt.Value.(TypeValue); !is { - ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt + if dt.Value != nil { + if _, is := dt.Value.(TypeValue); !is { + ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt + } } } @@ -583,8 +575,14 @@ func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode } else { switch formatCode { case BinaryFormatCode: + if dt.binaryDecoder == nil { + return fmt.Errorf("dt.binaryDecoder is nil") + } err = dt.binaryDecoder.DecodeBinary(ci, src) case TextFormatCode: + if dt.textDecoder == nil { + return fmt.Errorf("dt.textDecoder is nil") + } err = dt.textDecoder.DecodeText(ci, src) } } @@ -782,14 +780,105 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt return newPlan.Scan(ci, oid, formatCode, src, dst) } +type pointerPointerScanPlan struct { + dstType reflect.Type + next ScanPlan +} + +func (plan *pointerPointerScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if plan.dstType != reflect.TypeOf(dst) { + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + + el := reflect.ValueOf(dst).Elem() + if src == nil { + el.Set(reflect.Zero(el.Type())) + return nil + } + + el.Set(reflect.New(el.Type().Elem())) + return plan.next.Scan(ci, oid, formatCode, src, el.Interface()) +} + +func tryPointerPointerScanPlan(dst interface{}) (plan *pointerPointerScanPlan, nextDst interface{}, ok bool) { + if dstValue := reflect.ValueOf(dst); dstValue.Kind() == reflect.Ptr { + elemValue := dstValue.Elem() + if elemValue.Kind() == reflect.Ptr { + plan = &pointerPointerScanPlan{dstType: dstValue.Type()} + return plan, reflect.Zero(elemValue.Type()).Interface(), true + } + } + + return nil, nil, false +} + +var elemKindToBasePointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ + reflect.Int: reflect.TypeOf(new(int)), + reflect.Int8: reflect.TypeOf(new(int8)), + reflect.Int16: reflect.TypeOf(new(int16)), + reflect.Int32: reflect.TypeOf(new(int32)), + reflect.Int64: reflect.TypeOf(new(int64)), + reflect.Uint: reflect.TypeOf(new(uint)), + reflect.Uint8: reflect.TypeOf(new(uint8)), + reflect.Uint16: reflect.TypeOf(new(uint16)), + reflect.Uint32: reflect.TypeOf(new(uint32)), + reflect.Uint64: reflect.TypeOf(new(uint64)), + reflect.Float32: reflect.TypeOf(new(float32)), + reflect.Float64: reflect.TypeOf(new(float64)), + reflect.String: reflect.TypeOf(new(string)), +} + +type baseTypeScanPlan struct { + dstType reflect.Type + nextDstType reflect.Type + next ScanPlan +} + +func (plan *baseTypeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if plan.dstType != reflect.TypeOf(dst) { + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + + return plan.next.Scan(ci, oid, formatCode, src, reflect.ValueOf(dst).Convert(plan.nextDstType).Interface()) +} + +func tryBaseTypeScanPlan(dst interface{}) (plan *baseTypeScanPlan, nextDst interface{}, ok bool) { + dstValue := reflect.ValueOf(dst) + + if dstValue.Kind() == reflect.Ptr { + elemValue := dstValue.Elem() + nextDstType := elemKindToBasePointerTypes[elemValue.Kind()] + if nextDstType != nil { + return &baseTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true + } + } + + return nil, nil, false +} + // PlanScan prepares a plan to scan a value into dst. func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan { if oid != 0 { if dt, ok := ci.DataTypeForOID(oid); ok && dt.Codec != nil { - plan := dt.Codec.PlanScan(ci, oid, formatCode, dst, false) - if plan != nil { + if plan := dt.Codec.PlanScan(ci, oid, formatCode, dst, false); plan != nil { return plan } + + if pointerPointerPlan, nextDst, ok := tryPointerPointerScanPlan(dst); ok { + if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { + pointerPointerPlan.next = nextPlan + return pointerPointerPlan + } + } + + if baseTypePlan, nextDst, ok := tryBaseTypeScanPlan(dst); ok { + if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { + baseTypePlan.next = nextPlan + return baseTypePlan + } + } } } @@ -908,77 +997,3 @@ func NewValue(v Value) Value { } var ErrScanTargetTypeChanged = errors.New("scan target type changed") - -var nameValues map[string]Value - -func init() { - nameValues = map[string]Value{ - "_aclitem": &ACLItemArray{}, - "_bool": &BoolArray{}, - "_bpchar": &BPCharArray{}, - "_bytea": &ByteaArray{}, - "_cidr": &CIDRArray{}, - "_date": &DateArray{}, - "_float4": &Float4Array{}, - "_float8": &Float8Array{}, - "_inet": &InetArray{}, - "_int2": &Int2Array{}, - "_int4": &Int4Array{}, - "_int8": &Int8Array{}, - "_numeric": &NumericArray{}, - "_text": &TextArray{}, - "_timestamp": &TimestampArray{}, - "_timestamptz": &TimestamptzArray{}, - "_uuid": &UUIDArray{}, - "_varchar": &VarcharArray{}, - "_jsonb": &JSONBArray{}, - "aclitem": &ACLItem{}, - "bit": &Bit{}, - "bool": &Bool{}, - "box": &Box{}, - "bpchar": &BPChar{}, - "bytea": &Bytea{}, - "char": &QChar{}, - "cid": &CID{}, - "cidr": &CIDR{}, - "circle": &Circle{}, - "date": &Date{}, - "daterange": &Daterange{}, - "float4": &Float4{}, - "float8": &Float8{}, - "hstore": &Hstore{}, - "inet": &Inet{}, - "int2": &Int2{}, - "int4": &Int4{}, - "int4range": &Int4range{}, - "int8": &Int8{}, - "int8range": &Int8range{}, - "interval": &Interval{}, - "json": &JSON{}, - "jsonb": &JSONB{}, - "line": &Line{}, - "lseg": &Lseg{}, - "macaddr": &Macaddr{}, - "name": &Name{}, - "numeric": &Numeric{}, - "numrange": &Numrange{}, - "oid": &OIDValue{}, - "path": &Path{}, - "point": &Point{}, - "polygon": &Polygon{}, - "record": &Record{}, - "text": &Text{}, - "tid": &TID{}, - "timestamp": &Timestamp{}, - "timestamptz": &Timestamptz{}, - "tsrange": &Tsrange{}, - "_tsrange": &TsrangeArray{}, - "tstzrange": &Tstzrange{}, - "_tstzrange": &TstzrangeArray{}, - "unknown": &Unknown{}, - "uuid": &UUID{}, - "varbit": &Varbit{}, - "varchar": &Varchar{}, - "xid": &XID{}, - } -} diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index ea28be07..ae0e67cb 100755 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -1,4 +1,3 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go diff --git a/pgtype/zzz.int2.go b/pgtype/zzz.int2.go deleted file mode 100644 index f2d959f9..00000000 --- a/pgtype/zzz.int2.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Int2) BinaryFormatSupported() bool { - return true -} - -func (Int2) TextFormatSupported() bool { - return true -} - -func (Int2) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Int2) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Int2) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -} diff --git a/query_test.go b/query_test.go index e725bd40..d9b35e28 100644 --- a/query_test.go +++ b/query_test.go @@ -920,65 +920,64 @@ func TestQueryRowCoreIntegerDecoding(t *testing.T) { } failedDecodeTests := []struct { - sql string - scanArg interface{} - expectedErr string + sql string + scanArg interface{} }{ // Check any integer type where value is outside Go:int8 range cannot be decoded - {"select 128::int2", &actual.i8, "is greater than"}, - {"select 128::int4", &actual.i8, "is greater than"}, - {"select 128::int8", &actual.i8, "is greater than"}, - {"select -129::int2", &actual.i8, "is less than"}, - {"select -129::int4", &actual.i8, "is less than"}, - {"select -129::int8", &actual.i8, "is less than"}, + {"select 128::int2", &actual.i8}, + {"select 128::int4", &actual.i8}, + {"select 128::int8", &actual.i8}, + {"select -129::int2", &actual.i8}, + {"select -129::int4", &actual.i8}, + {"select -129::int8", &actual.i8}, // Check any integer type where value is outside Go:int16 range cannot be decoded - {"select 32768::int4", &actual.i16, "is greater than"}, - {"select 32768::int8", &actual.i16, "is greater than"}, - {"select -32769::int4", &actual.i16, "is less than"}, - {"select -32769::int8", &actual.i16, "is less than"}, + {"select 32768::int4", &actual.i16}, + {"select 32768::int8", &actual.i16}, + {"select -32769::int4", &actual.i16}, + {"select -32769::int8", &actual.i16}, // Check any integer type where value is outside Go:int32 range cannot be decoded - {"select 2147483648::int8", &actual.i32, "is greater than"}, - {"select -2147483649::int8", &actual.i32, "is less than"}, + {"select 2147483648::int8", &actual.i32}, + {"select -2147483649::int8", &actual.i32}, // Check any integer type where value is outside Go:uint range cannot be decoded - {"select -1::int2", &actual.ui, "is less than"}, - {"select -1::int4", &actual.ui, "is less than"}, - {"select -1::int8", &actual.ui, "is less than"}, + {"select -1::int2", &actual.ui}, + {"select -1::int4", &actual.ui}, + {"select -1::int8", &actual.ui}, // Check any integer type where value is outside Go:uint8 range cannot be decoded - {"select 256::int2", &actual.ui8, "is greater than"}, - {"select 256::int4", &actual.ui8, "is greater than"}, - {"select 256::int8", &actual.ui8, "is greater than"}, - {"select -1::int2", &actual.ui8, "is less than"}, - {"select -1::int4", &actual.ui8, "is less than"}, - {"select -1::int8", &actual.ui8, "is less than"}, + {"select 256::int2", &actual.ui8}, + {"select 256::int4", &actual.ui8}, + {"select 256::int8", &actual.ui8}, + {"select -1::int2", &actual.ui8}, + {"select -1::int4", &actual.ui8}, + {"select -1::int8", &actual.ui8}, // Check any integer type where value is outside Go:uint16 cannot be decoded - {"select 65536::int4", &actual.ui16, "is greater than"}, - {"select 65536::int8", &actual.ui16, "is greater than"}, - {"select -1::int2", &actual.ui16, "is less than"}, - {"select -1::int4", &actual.ui16, "is less than"}, - {"select -1::int8", &actual.ui16, "is less than"}, + {"select 65536::int4", &actual.ui16}, + {"select 65536::int8", &actual.ui16}, + {"select -1::int2", &actual.ui16}, + {"select -1::int4", &actual.ui16}, + {"select -1::int8", &actual.ui16}, // Check any integer type where value is outside Go:uint32 range cannot be decoded - {"select 4294967296::int8", &actual.ui32, "is greater than"}, - {"select -1::int2", &actual.ui32, "is less than"}, - {"select -1::int4", &actual.ui32, "is less than"}, - {"select -1::int8", &actual.ui32, "is less than"}, + {"select 4294967296::int8", &actual.ui32}, + {"select -1::int2", &actual.ui32}, + {"select -1::int4", &actual.ui32}, + {"select -1::int8", &actual.ui32}, // Check any integer type where value is outside Go:uint64 range cannot be decoded - {"select -1::int2", &actual.ui64, "is less than"}, - {"select -1::int4", &actual.ui64, "is less than"}, - {"select -1::int8", &actual.ui64, "is less than"}, + {"select -1::int2", &actual.ui64}, + {"select -1::int4", &actual.ui64}, + {"select -1::int8", &actual.ui64}, } for i, tt := range failedDecodeTests { err := conn.QueryRow(context.Background(), tt.sql).Scan(tt.scanArg) if err == nil { t.Errorf("%d. Expected failure to decode, but unexpectedly succeeded: %v (sql -> %v)", i, err, tt.sql) - } else if !strings.Contains(err.Error(), tt.expectedErr) { + } else if !strings.Contains(err.Error(), "can't scan") { t.Errorf("%d. Expected failure to decode, but got: %v (sql -> %v)", i, err, tt.sql) } diff --git a/rows.go b/rows.go index cc6e26d5..0cc09ad9 100644 --- a/rows.go +++ b/rows.go @@ -246,31 +246,40 @@ func (rows *connRows) Values() ([]interface{}, error) { } if dt, ok := rows.connInfo.DataTypeForOID(fd.DataTypeOID); ok { - value := dt.Value + if dt.Value != nil { - switch fd.Format { - case TextFormatCode: - decoder, ok := value.(pgtype.TextDecoder) - if !ok { - decoder = &pgtype.GenericText{} + value := dt.Value + + switch fd.Format { + case TextFormatCode: + decoder, ok := value.(pgtype.TextDecoder) + if !ok { + decoder = &pgtype.GenericText{} + } + err := decoder.DecodeText(rows.connInfo, buf) + if err != nil { + rows.fatal(err) + } + values = append(values, decoder.(pgtype.Value).Get()) + case BinaryFormatCode: + decoder, ok := value.(pgtype.BinaryDecoder) + if !ok { + decoder = &pgtype.GenericBinary{} + } + err := decoder.DecodeBinary(rows.connInfo, buf) + if err != nil { + rows.fatal(err) + } + values = append(values, value.Get()) + default: + rows.fatal(errors.New("Unknown format code")) } - err := decoder.DecodeText(rows.connInfo, buf) + } else if dt.Codec != nil { + value, err := dt.Codec.DecodeValue(rows.connInfo, fd.DataTypeOID, fd.Format, buf) if err != nil { rows.fatal(err) } - values = append(values, decoder.(pgtype.Value).Get()) - case BinaryFormatCode: - decoder, ok := value.(pgtype.BinaryDecoder) - if !ok { - decoder = &pgtype.GenericBinary{} - } - err := decoder.DecodeBinary(rows.connInfo, buf) - if err != nil { - rows.fatal(err) - } - values = append(values, value.Get()) - default: - rows.fatal(errors.New("Unknown format code")) + values = append(values, value) } } else { switch fd.Format { diff --git a/values.go b/values.go index 2f328b82..00606689 100644 --- a/values.go +++ b/values.go @@ -115,19 +115,30 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e } if dt, found := ci.DataTypeForValue(arg); found { - v := dt.Value - err := v.Set(arg) - if err != nil { - return nil, err + if dt.Value != nil { + v := dt.Value + err := v.Set(arg) + if err != nil { + return nil, err + } + buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), nil + } else if dt.Codec != nil { + buf, err := dt.Codec.Encode(ci, 0, TextFormatCode, arg, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), nil } - buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil } if refVal.Kind() == reflect.Ptr { @@ -188,33 +199,47 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32 } if dt, ok := ci.DataTypeForOID(oid); ok { - value := dt.Value - err := value.Set(arg) - if err != nil { - { - if arg, ok := arg.(driver.Valuer); ok { - v, err := callValuerValue(arg) - if err != nil { - return nil, err + if dt.Value != nil { + value := dt.Value + err := value.Set(arg) + if err != nil { + { + if arg, ok := arg.(driver.Valuer); ok { + v, err := callValuerValue(arg) + if err != nil { + return nil, err + } + return encodePreparedStatementArgument(ci, buf, oid, v) } - return encodePreparedStatementArgument(ci, buf, oid, v) } + + return nil, err } - return nil, err + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + return buf, nil + } else if dt.Codec != nil { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := dt.Codec.Encode(ci, oid, BinaryFormatCode, arg, buf) + if err != nil { + return nil, err + } + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + return buf, nil } - - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if argBuf != nil { - buf = argBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - return buf, nil } if strippedArg, ok := stripNamedType(&refVal); ok {