2
0

Initial passing tests for main pgx package

This commit is contained in:
Jack Christensen
2021-12-30 18:12:47 -06:00
parent 58b7486343
commit 9fc8f9b3a8
12 changed files with 574 additions and 1567 deletions
+24 -13
View File
@@ -113,23 +113,34 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o
}
if dt, ok := ci.DataTypeForOID(oid); ok {
value := dt.Value
err := value.Set(arg)
if err != nil {
{
if arg, ok := arg.(driver.Valuer); ok {
v, err := callValuerValue(arg)
if err != nil {
return nil, err
if dt.Value != nil {
value := dt.Value
err := value.Set(arg)
if err != nil {
{
if arg, ok := arg.(driver.Valuer); ok {
v, err := callValuerValue(arg)
if err != nil {
return nil, err
}
return eqb.encodeExtendedParamValue(ci, oid, formatCode, v)
}
return eqb.encodeExtendedParamValue(ci, oid, formatCode, v)
}
return nil, err
}
return nil, err
return eqb.encodeExtendedParamValue(ci, oid, formatCode, value)
} else if dt.Codec != nil {
buf, err := dt.Codec.Encode(ci, oid, formatCode, arg, eqb.paramValueBytes)
if err != nil {
return nil, err
}
if buf == nil {
return nil, nil
}
eqb.paramValueBytes = buf
return eqb.paramValueBytes[pos:], nil
}
return eqb.encodeExtendedParamValue(ci, oid, formatCode, value)
}
// There is no data type registered for the destination OID, but maybe there is data type registered for the arg
+2 -53
View File
@@ -28,57 +28,6 @@ type ArraySetter interface {
ScanIndex(i int) interface{}
}
type int16Array []int16
func (a int16Array) Dimensions() []ArrayDimension {
if a == nil {
return nil
}
return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}}
}
func (a int16Array) Index(i int) interface{} {
return a[i]
}
func (a *int16Array) SetDimensions(dimensions []ArrayDimension) error {
if dimensions == nil {
a = nil
return nil
}
elementCount := cardinality(dimensions)
*a = make(int16Array, elementCount)
return nil
}
func (a int16Array) ScanIndex(i int) interface{} {
return &a[i]
}
func makeArrayGetter(a interface{}) (ArrayGetter, error) {
switch a := a.(type) {
case ArrayGetter:
return a, nil
case []int16:
return (*int16Array)(&a), nil
}
return nil, fmt.Errorf("cannot convert %T to ArrayGetter", a)
}
func makeArraySetter(a interface{}) (ArraySetter, error) {
switch a := a.(type) {
case ArraySetter:
return a, nil
case *[]int16:
return (*int16Array)(a), nil
}
return nil, fmt.Errorf("cannot convert %T to ArraySetter", a)
}
// ArrayCodec is a codec for any array type.
type ArrayCodec struct {
ElementCodec Codec
@@ -155,7 +104,8 @@ func (c *ArrayCodec) encodeText(ci *ConnInfo, oid uint32, array ArrayGetter, buf
return nil, nil
}
if len(dimensions) == 0 {
elementCount := cardinality(dimensions)
if elementCount == 0 {
return append(buf, '{', '}'), nil
}
@@ -173,7 +123,6 @@ func (c *ArrayCodec) encodeText(ci *ConnInfo, oid uint32, array ArrayGetter, buf
}
inElemBuf := make([]byte, 0, 32)
elementCount := cardinality(dimensions)
for i := 0; i < elementCount; i++ {
if i > 0 {
buf = append(buf, ',')
+24 -241
View File
@@ -2,12 +2,9 @@ package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"math"
"strconv"
"github.com/jackc/pgio"
)
type Int2 struct {
@@ -15,231 +12,6 @@ type Int2 struct {
Valid bool
}
func (dst *Int2) Set(src interface{}) error {
if src == nil {
*dst = Int2{}
return nil
}
if value, ok := src.(interface{ Get() interface{} }); ok {
value2 := value.Get()
if value2 != value {
return dst.Set(value2)
}
}
switch value := src.(type) {
case int8:
*dst = Int2{Int: int16(value), Valid: true}
case uint8:
*dst = Int2{Int: int16(value), Valid: true}
case int16:
*dst = Int2{Int: int16(value), Valid: true}
case uint16:
if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
*dst = Int2{Int: int16(value), Valid: true}
case int32:
if value < math.MinInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
*dst = Int2{Int: int16(value), Valid: true}
case uint32:
if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
*dst = Int2{Int: int16(value), Valid: true}
case int64:
if value < math.MinInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
*dst = Int2{Int: int16(value), Valid: true}
case uint64:
if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
*dst = Int2{Int: int16(value), Valid: true}
case int:
if value < math.MinInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
*dst = Int2{Int: int16(value), Valid: true}
case uint:
if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", value)
}
*dst = Int2{Int: int16(value), Valid: true}
case string:
num, err := strconv.ParseInt(value, 10, 16)
if err != nil {
return err
}
*dst = Int2{Int: int16(num), Valid: true}
case float32:
if value > math.MaxInt16 {
return fmt.Errorf("%f is greater than maximum value for Int2", value)
}
*dst = Int2{Int: int16(value), Valid: true}
case float64:
if value > math.MaxInt16 {
return fmt.Errorf("%f is greater than maximum value for Int2", value)
}
*dst = Int2{Int: int16(value), Valid: true}
case *int8:
if value == nil {
*dst = Int2{}
} else {
return dst.Set(*value)
}
case *uint8:
if value == nil {
*dst = Int2{}
} else {
return dst.Set(*value)
}
case *int16:
if value == nil {
*dst = Int2{}
} else {
return dst.Set(*value)
}
case *uint16:
if value == nil {
*dst = Int2{}
} else {
return dst.Set(*value)
}
case *int32:
if value == nil {
*dst = Int2{}
} else {
return dst.Set(*value)
}
case *uint32:
if value == nil {
*dst = Int2{}
} else {
return dst.Set(*value)
}
case *int64:
if value == nil {
*dst = Int2{}
} else {
return dst.Set(*value)
}
case *uint64:
if value == nil {
*dst = Int2{}
} else {
return dst.Set(*value)
}
case *int:
if value == nil {
*dst = Int2{}
} else {
return dst.Set(*value)
}
case *uint:
if value == nil {
*dst = Int2{}
} else {
return dst.Set(*value)
}
case *string:
if value == nil {
*dst = Int2{}
} else {
return dst.Set(*value)
}
case *float32:
if value == nil {
*dst = Int2{}
} else {
return dst.Set(*value)
}
case *float64:
if value == nil {
*dst = Int2{}
} else {
return dst.Set(*value)
}
default:
if originalSrc, ok := underlyingNumberType(src); ok {
return dst.Set(originalSrc)
}
return fmt.Errorf("cannot convert %v to Int2", value)
}
return nil
}
func (dst Int2) Get() interface{} {
if !dst.Valid {
return nil
}
return dst.Int
}
func (src *Int2) AssignTo(dst interface{}) error {
return int64AssignTo(int64(src.Int), src.Valid, dst)
}
func (dst *Int2) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Int2{}
return nil
}
n, err := strconv.ParseInt(string(src), 10, 16)
if err != nil {
return err
}
*dst = Int2{Int: int16(n), Valid: true}
return nil
}
func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Int2{}
return nil
}
if len(src) != 2 {
return fmt.Errorf("invalid length for int2: %v", len(src))
}
n := int16(binary.BigEndian.Uint16(src))
*dst = Int2{Int: n, Valid: true}
return nil
}
func (src Int2) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
if !src.Valid {
return nil, nil
}
return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil
}
func (src Int2) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
if !src.Valid {
return nil, nil
}
return pgio.AppendInt16(buf, src.Int), nil
}
// Scan implements the database/sql Scanner interface.
func (dst *Int2) Scan(src interface{}) error {
if src == nil {
@@ -247,25 +19,36 @@ func (dst *Int2) Scan(src interface{}) error {
return nil
}
var n int64
switch src := src.(type) {
case int64:
if src < math.MinInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", src)
}
if src > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", src)
}
*dst = Int2{Int: int16(src), Valid: true}
return nil
n = src
case string:
return dst.DecodeText(nil, []byte(src))
var err error
n, err = strconv.ParseInt(src, 10, 16)
if err != nil {
return err
}
case []byte:
srcCopy := make([]byte, len(src))
copy(srcCopy, src)
return dst.DecodeText(nil, srcCopy)
var err error
n, err = strconv.ParseInt(string(src), 10, 16)
if err != nil {
return err
}
default:
return fmt.Errorf("cannot scan %T", src)
}
return fmt.Errorf("cannot scan %T", src)
if n < math.MinInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", n)
}
if n > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", n)
}
*dst = Int2{Int: int16(n), Valid: true}
return nil
}
// Value implements the database/sql/driver Valuer interface.
-896
View File
@@ -1,896 +0,0 @@
// Code generated by erb. DO NOT EDIT.
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"reflect"
"github.com/jackc/pgio"
)
type Int2Array struct {
Elements []Int2
Dimensions []ArrayDimension
Valid bool
}
func (dst *Int2Array) Set(src interface{}) error {
// untyped nil and typed nil interfaces are different
if src == nil {
*dst = Int2Array{}
return nil
}
if value, ok := src.(interface{ Get() interface{} }); ok {
value2 := value.Get()
if value2 != value {
return dst.Set(value2)
}
}
// Attempt to match to select common types:
switch value := src.(type) {
case []int16:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []*int16:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []uint16:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []*uint16:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []int32:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []*int32:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []uint32:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []*uint32:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []int64:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []*int64:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []uint64:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []*uint64:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []int:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []*int:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []uint:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []*uint:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
elements := make([]Int2, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int2Array{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Valid: true,
}
}
case []Int2:
if value == nil {
*dst = Int2Array{}
} else if len(value) == 0 {
*dst = Int2Array{Valid: true}
} else {
*dst = Int2Array{
Elements: value,
Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}},
Valid: true,
}
}
default:
// Fallback to reflection if an optimised match was not found.
// The reflection is necessary for arrays and multidimensional slices,
// but it comes with a 20-50% performance penalty for large arrays/slices
reflectedValue := reflect.ValueOf(src)
if !reflectedValue.IsValid() || reflectedValue.IsZero() {
*dst = Int2Array{}
return nil
}
dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0)
if !ok {
return fmt.Errorf("cannot find dimensions of %v for Int2Array", src)
}
if elementsLength == 0 {
*dst = Int2Array{Valid: true}
return nil
}
if len(dimensions) == 0 {
if originalSrc, ok := underlyingSliceType(src); ok {
return dst.Set(originalSrc)
}
return fmt.Errorf("cannot convert %v to Int2Array", src)
}
*dst = Int2Array{
Elements: make([]Int2, elementsLength),
Dimensions: dimensions,
Valid: true,
}
elementCount, err := dst.setRecursive(reflectedValue, 0, 0)
if err != nil {
// Maybe the target was one dimension too far, try again:
if len(dst.Dimensions) > 1 {
dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1]
elementsLength = 0
for _, dim := range dst.Dimensions {
if elementsLength == 0 {
elementsLength = int(dim.Length)
} else {
elementsLength *= int(dim.Length)
}
}
dst.Elements = make([]Int2, elementsLength)
elementCount, err = dst.setRecursive(reflectedValue, 0, 0)
if err != nil {
return err
}
} else {
return err
}
}
if elementCount != len(dst.Elements) {
return fmt.Errorf("cannot convert %v to Int2Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount)
}
}
return nil
}
func (dst *Int2Array) setRecursive(value reflect.Value, index, dimension int) (int, error) {
switch value.Kind() {
case reflect.Array:
fallthrough
case reflect.Slice:
if len(dst.Dimensions) == dimension {
break
}
valueLen := value.Len()
if int32(valueLen) != dst.Dimensions[dimension].Length {
return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions")
}
for i := 0; i < valueLen; i++ {
var err error
index, err = dst.setRecursive(value.Index(i), index, dimension+1)
if err != nil {
return 0, err
}
}
return index, nil
}
if !value.CanInterface() {
return 0, fmt.Errorf("cannot convert all values to Int2Array")
}
if err := dst.Elements[index].Set(value.Interface()); err != nil {
return 0, fmt.Errorf("%v in Int2Array", err)
}
index++
return index, nil
}
func (dst Int2Array) Get() interface{} {
if !dst.Valid {
return nil
}
return dst
}
func (src *Int2Array) AssignTo(dst interface{}) error {
if !src.Valid {
return NullAssignTo(dst)
}
if len(src.Dimensions) <= 1 {
// Attempt to match to select common types:
switch v := dst.(type) {
case *[]int16:
*v = make([]int16, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]*int16:
*v = make([]*int16, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]uint16:
*v = make([]uint16, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]*uint16:
*v = make([]*uint16, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]int32:
*v = make([]int32, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]*int32:
*v = make([]*int32, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]uint32:
*v = make([]uint32, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]*uint32:
*v = make([]*uint32, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]int64:
*v = make([]int64, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]*int64:
*v = make([]*int64, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]uint64:
*v = make([]uint64, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]*uint64:
*v = make([]*uint64, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]int:
*v = make([]int, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]*int:
*v = make([]*int, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]uint:
*v = make([]uint, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]*uint:
*v = make([]*uint, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
}
}
// Try to convert to something AssignTo can use directly.
if nextDst, retry := GetAssignToDstType(dst); retry {
return src.AssignTo(nextDst)
}
// Fallback to reflection if an optimised match was not found.
// The reflection is necessary for arrays and multidimensional slices,
// but it comes with a 20-50% performance penalty for large arrays/slices
value := reflect.ValueOf(dst)
if value.Kind() == reflect.Ptr {
value = value.Elem()
}
switch value.Kind() {
case reflect.Array, reflect.Slice:
default:
return fmt.Errorf("cannot assign %T to %T", src, dst)
}
if len(src.Elements) == 0 {
if value.Kind() == reflect.Slice {
value.Set(reflect.MakeSlice(value.Type(), 0, 0))
return nil
}
}
elementCount, err := src.assignToRecursive(value, 0, 0)
if err != nil {
return err
}
if elementCount != len(src.Elements) {
return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount)
}
return nil
}
func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) {
switch kind := value.Kind(); kind {
case reflect.Array:
fallthrough
case reflect.Slice:
if len(src.Dimensions) == dimension {
break
}
length := int(src.Dimensions[dimension].Length)
if reflect.Array == kind {
typ := value.Type()
if typ.Len() != length {
return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len())
}
value.Set(reflect.New(typ).Elem())
} else {
value.Set(reflect.MakeSlice(value.Type(), length, length))
}
var err error
for i := 0; i < length; i++ {
index, err = src.assignToRecursive(value.Index(i), index, dimension+1)
if err != nil {
return 0, err
}
}
return index, nil
}
if len(src.Dimensions) != dimension {
return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension)
}
if !value.CanAddr() {
return 0, fmt.Errorf("cannot assign all values from Int2Array")
}
addr := value.Addr()
if !addr.CanInterface() {
return 0, fmt.Errorf("cannot assign all values from Int2Array")
}
if err := src.Elements[index].AssignTo(addr.Interface()); err != nil {
return 0, err
}
index++
return index, nil
}
func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Int2Array{}
return nil
}
uta, err := ParseUntypedTextArray(string(src))
if err != nil {
return err
}
var elements []Int2
if len(uta.Elements) > 0 {
elements = make([]Int2, len(uta.Elements))
for i, s := range uta.Elements {
var elem Int2
var elemSrc []byte
if s != "NULL" || uta.Quoted[i] {
elemSrc = []byte(s)
}
err = elem.DecodeText(ci, elemSrc)
if err != nil {
return err
}
elements[i] = elem
}
}
*dst = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true}
return nil
}
func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Int2Array{}
return nil
}
var arrayHeader ArrayHeader
rp, err := arrayHeader.DecodeBinary(ci, src)
if err != nil {
return err
}
if len(arrayHeader.Dimensions) == 0 {
*dst = Int2Array{Dimensions: arrayHeader.Dimensions, Valid: true}
return nil
}
elementCount := arrayHeader.Dimensions[0].Length
for _, d := range arrayHeader.Dimensions[1:] {
elementCount *= d.Length
}
elements := make([]Int2, elementCount)
for i := range elements {
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
var elemSrc []byte
if elemLen >= 0 {
elemSrc = src[rp : rp+elemLen]
rp += elemLen
}
err = elements[i].DecodeBinary(ci, elemSrc)
if err != nil {
return err
}
}
*dst = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true}
return nil
}
func (src Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
if !src.Valid {
return nil, nil
}
if len(src.Dimensions) == 0 {
return append(buf, '{', '}'), nil
}
buf = EncodeTextArrayDimensions(buf, src.Dimensions)
// dimElemCounts is the multiples of elements that each array lies on. For
// example, a single dimension array of length 4 would have a dimElemCounts of
// [4]. A multi-dimensional array of lengths [3,5,2] would have a
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
// or '}'.
dimElemCounts := make([]int, len(src.Dimensions))
dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
for i := len(src.Dimensions) - 2; i > -1; i-- {
dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
}
inElemBuf := make([]byte, 0, 32)
for i, elem := range src.Elements {
if i > 0 {
buf = append(buf, ',')
}
for _, dec := range dimElemCounts {
if i%dec == 0 {
buf = append(buf, '{')
}
}
elemBuf, err := elem.EncodeText(ci, inElemBuf)
if err != nil {
return nil, err
}
if elemBuf == nil {
buf = append(buf, `NULL`...)
} else {
buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
}
for _, dec := range dimElemCounts {
if (i+1)%dec == 0 {
buf = append(buf, '}')
}
}
}
return buf, nil
}
func (src Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
if !src.Valid {
return nil, nil
}
arrayHeader := ArrayHeader{
Dimensions: src.Dimensions,
}
if dt, ok := ci.DataTypeForName("int2"); ok {
arrayHeader.ElementOID = int32(dt.OID)
} else {
return nil, fmt.Errorf("unable to find oid for type name %v", "int2")
}
for i := range src.Elements {
if !src.Elements[i].Valid {
arrayHeader.ContainsNull = true
break
}
}
buf = arrayHeader.EncodeBinary(ci, buf)
for i := range src.Elements {
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
elemBuf, err := src.Elements[i].EncodeBinary(ci, buf)
if err != nil {
return nil, err
}
if elemBuf != nil {
buf = elemBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
}
return buf, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *Int2Array) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
srcCopy := make([]byte, len(src))
copy(srcCopy, src)
return dst.DecodeText(nil, srcCopy)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Int2Array) Value() (driver.Value, error) {
buf, err := src.EncodeText(nil, nil)
if err != nil {
return nil, err
}
if buf == nil {
return nil, nil
}
return string(buf), nil
}
+210 -13
View File
@@ -2,6 +2,7 @@ package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"math"
"strconv"
@@ -46,16 +47,31 @@ func (Int2Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{
}
func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan {
switch format {
case BinaryFormatCode:
case TextFormatCode:
switch target.(type) {
case *int8:
return scanPlanTextAnyToInt8{}
case *int16:
return scanPlanTextToAnyInt16{}
return scanPlanTextAnyToInt16{}
case *int32:
return scanPlanTextToAnyInt32{}
return scanPlanTextAnyToInt32{}
case *int64:
return scanPlanTextToAnyInt64{}
return scanPlanTextAnyToInt64{}
case *int:
return scanPlanTextAnyToInt{}
case *uint8:
return scanPlanTextAnyToUint8{}
case *uint16:
return scanPlanTextAnyToUint16{}
case *uint32:
return scanPlanTextAnyToUint32{}
case *uint64:
return scanPlanTextAnyToUint64{}
case *uint:
return scanPlanTextAnyToUint{}
}
}
@@ -68,8 +84,15 @@ func (c Int2Codec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16
}
var n int64
err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n)
return n, err
scanPlan := c.PlanScan(ci, oid, format, &n, true)
if scanPlan == nil {
return nil, fmt.Errorf("PlanScan did not find a plan")
}
err := scanPlan.Scan(ci, oid, format, src, &n)
if err != nil {
return nil, err
}
return n, nil
}
func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) {
@@ -78,13 +101,61 @@ func (c Int2Codec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byt
}
var n int16
err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n)
return n, err
scanPlan := c.PlanScan(ci, oid, format, &n, true)
if scanPlan == nil {
return nil, fmt.Errorf("PlanScan did not find a plan")
}
err := scanPlan.Scan(ci, oid, format, src, &n)
if err != nil {
return nil, err
}
return n, nil
}
type scanPlanTextToAnyInt16 struct{}
type scanPlanBinaryInt2ToInt16 struct{}
func (scanPlanTextToAnyInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
func (scanPlanBinaryInt2ToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
if src == nil {
return fmt.Errorf("cannot scan null into %T", dst)
}
if len(src) != 2 {
return fmt.Errorf("invalid length for int2: %v", len(src))
}
p, ok := (dst).(*int16)
if !ok {
return ErrScanTargetTypeChanged
}
*p = int16(binary.BigEndian.Uint16(src))
return nil
}
type scanPlanTextAnyToInt8 struct{}
func (scanPlanTextAnyToInt8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
if src == nil {
return fmt.Errorf("cannot scan null into %T", dst)
}
p, ok := (dst).(*int8)
if !ok {
return ErrScanTargetTypeChanged
}
n, err := strconv.ParseInt(string(src), 10, 8)
if err != nil {
return err
}
*p = int8(n)
return nil
}
type scanPlanTextAnyToInt16 struct{}
func (scanPlanTextAnyToInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
if src == nil {
return fmt.Errorf("cannot scan null into %T", dst)
}
@@ -103,9 +174,9 @@ func (scanPlanTextToAnyInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, s
return nil
}
type scanPlanTextToAnyInt32 struct{}
type scanPlanTextAnyToInt32 struct{}
func (scanPlanTextToAnyInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
func (scanPlanTextAnyToInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
if src == nil {
return fmt.Errorf("cannot scan null into %T", dst)
}
@@ -124,9 +195,9 @@ func (scanPlanTextToAnyInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, s
return nil
}
type scanPlanTextToAnyInt64 struct{}
type scanPlanTextAnyToInt64 struct{}
func (scanPlanTextToAnyInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
func (scanPlanTextAnyToInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
if src == nil {
return fmt.Errorf("cannot scan null into %T", dst)
}
@@ -144,3 +215,129 @@ func (scanPlanTextToAnyInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, s
*p = int64(n)
return nil
}
type scanPlanTextAnyToInt struct{}
func (scanPlanTextAnyToInt) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
if src == nil {
return fmt.Errorf("cannot scan null into %T", dst)
}
p, ok := (dst).(*int)
if !ok {
return ErrScanTargetTypeChanged
}
n, err := strconv.ParseInt(string(src), 10, 0)
if err != nil {
return err
}
*p = int(n)
return nil
}
type scanPlanTextAnyToUint8 struct{}
func (scanPlanTextAnyToUint8) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
if src == nil {
return fmt.Errorf("cannot scan null into %T", dst)
}
p, ok := (dst).(*uint8)
if !ok {
return ErrScanTargetTypeChanged
}
n, err := strconv.ParseUint(string(src), 10, 8)
if err != nil {
return err
}
*p = uint8(n)
return nil
}
type scanPlanTextAnyToUint16 struct{}
func (scanPlanTextAnyToUint16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
if src == nil {
return fmt.Errorf("cannot scan null into %T", dst)
}
p, ok := (dst).(*uint16)
if !ok {
return ErrScanTargetTypeChanged
}
n, err := strconv.ParseUint(string(src), 10, 16)
if err != nil {
return err
}
*p = uint16(n)
return nil
}
type scanPlanTextAnyToUint32 struct{}
func (scanPlanTextAnyToUint32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
if src == nil {
return fmt.Errorf("cannot scan null into %T", dst)
}
p, ok := (dst).(*uint32)
if !ok {
return ErrScanTargetTypeChanged
}
n, err := strconv.ParseUint(string(src), 10, 32)
if err != nil {
return err
}
*p = uint32(n)
return nil
}
type scanPlanTextAnyToUint64 struct{}
func (scanPlanTextAnyToUint64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
if src == nil {
return fmt.Errorf("cannot scan null into %T", dst)
}
p, ok := (dst).(*uint64)
if !ok {
return ErrScanTargetTypeChanged
}
n, err := strconv.ParseUint(string(src), 10, 64)
if err != nil {
return err
}
*p = uint64(n)
return nil
}
type scanPlanTextAnyToUint struct{}
func (scanPlanTextAnyToUint) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
if src == nil {
return fmt.Errorf("cannot scan null into %T", dst)
}
p, ok := (dst).(*uint)
if !ok {
return ErrScanTargetTypeChanged
}
n, err := strconv.ParseUint(string(src), 10, 0)
if err != nil {
return err
}
*p = uint(n)
return nil
}
+82 -131
View File
@@ -1,144 +1,95 @@
package pgtype_test
import (
"context"
"fmt"
"math"
"reflect"
"testing"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgtype/testutil"
)
func TestInt2Transcode(t *testing.T) {
testutil.TestSuccessfulTranscode(t, "int2", []interface{}{
&pgtype.Int2{Int: math.MinInt16, Valid: true},
&pgtype.Int2{Int: -1, Valid: true},
&pgtype.Int2{Int: 0, Valid: true},
&pgtype.Int2{Int: 1, Valid: true},
&pgtype.Int2{Int: math.MaxInt16, Valid: true},
&pgtype.Int2{Int: 0},
type PgxTranscodeTestCase struct {
src interface{}
dst interface{}
test func(interface{}) bool
}
func isExpectedEq(a interface{}) func(interface{}) bool {
return func(v interface{}) bool {
return a == v
}
}
func testPgxCodec(t testing.TB, pgTypeName string, tests []PgxTranscodeTestCase) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
_, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName))
if err != nil {
t.Fatal(err)
}
formats := []struct {
name string
code int16
}{
{name: "TextFormat", code: pgx.TextFormatCode},
{name: "BinaryFormat", code: pgx.BinaryFormatCode},
}
for i, tt := range tests {
for _, format := range formats {
err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{format.code}, tt.src).Scan(tt.dst)
if err != nil {
t.Errorf("%s %d: %v", format.name, i, err)
}
dst := reflect.ValueOf(tt.dst)
if dst.Kind() == reflect.Ptr {
dst = dst.Elem()
}
if !tt.test(dst.Interface()) {
t.Errorf("%s %d: unexpected result for %v: %v", format.name, i, tt.src, dst.Interface())
}
}
}
}
func TestInt2Codec(t *testing.T) {
testPgxCodec(t, "int2", []PgxTranscodeTestCase{
{int8(1), new(int16), isExpectedEq(int16(1))},
{int16(1), new(int16), isExpectedEq(int16(1))},
{int32(1), new(int16), isExpectedEq(int16(1))},
{int64(1), new(int16), isExpectedEq(int16(1))},
{uint8(1), new(int16), isExpectedEq(int16(1))},
{uint16(1), new(int16), isExpectedEq(int16(1))},
{uint32(1), new(int16), isExpectedEq(int16(1))},
{uint64(1), new(int16), isExpectedEq(int16(1))},
{int(1), new(int16), isExpectedEq(int16(1))},
{uint(1), new(int16), isExpectedEq(int16(1))},
{pgtype.Int2{Int: 1, Valid: true}, new(int16), isExpectedEq(int16(1))},
{1, new(int8), isExpectedEq(int8(1))},
{1, new(int16), isExpectedEq(int16(1))},
{1, new(int32), isExpectedEq(int32(1))},
{1, new(int64), isExpectedEq(int64(1))},
{1, new(uint8), isExpectedEq(uint8(1))},
{1, new(uint16), isExpectedEq(uint16(1))},
{1, new(uint32), isExpectedEq(uint32(1))},
{1, new(uint64), isExpectedEq(uint64(1))},
{1, new(int), isExpectedEq(int(1))},
{1, new(uint), isExpectedEq(uint(1))},
{math.MinInt16, new(int16), isExpectedEq(int16(math.MinInt16))},
{-1, new(int16), isExpectedEq(int16(-1))},
{0, new(int16), isExpectedEq(int16(0))},
{1, new(int16), isExpectedEq(int16(1))},
{math.MaxInt16, new(int16), isExpectedEq(int16(math.MaxInt16))},
{1, new(pgtype.Int2), isExpectedEq(pgtype.Int2{Int: 1, Valid: true})},
{pgtype.Int2{}, new(pgtype.Int2), isExpectedEq(pgtype.Int2{})},
{nil, new(*int16), isExpectedEq((*int16)(nil))},
})
}
func TestInt2Set(t *testing.T) {
successfulTests := []struct {
source interface{}
result pgtype.Int2
}{
{source: int8(1), result: pgtype.Int2{Int: 1, Valid: true}},
{source: int16(1), result: pgtype.Int2{Int: 1, Valid: true}},
{source: int32(1), result: pgtype.Int2{Int: 1, Valid: true}},
{source: int64(1), result: pgtype.Int2{Int: 1, Valid: true}},
{source: int8(-1), result: pgtype.Int2{Int: -1, Valid: true}},
{source: int16(-1), result: pgtype.Int2{Int: -1, Valid: true}},
{source: int32(-1), result: pgtype.Int2{Int: -1, Valid: true}},
{source: int64(-1), result: pgtype.Int2{Int: -1, Valid: true}},
{source: uint8(1), result: pgtype.Int2{Int: 1, Valid: true}},
{source: uint16(1), result: pgtype.Int2{Int: 1, Valid: true}},
{source: uint32(1), result: pgtype.Int2{Int: 1, Valid: true}},
{source: uint64(1), result: pgtype.Int2{Int: 1, Valid: true}},
{source: float32(1), result: pgtype.Int2{Int: 1, Valid: true}},
{source: float64(1), result: pgtype.Int2{Int: 1, Valid: true}},
{source: "1", result: pgtype.Int2{Int: 1, Valid: true}},
{source: _int8(1), result: pgtype.Int2{Int: 1, Valid: true}},
}
for i, tt := range successfulTests {
var r pgtype.Int2
err := r.Set(tt.source)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if r != tt.result {
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
}
}
}
func TestInt2AssignTo(t *testing.T) {
var i8 int8
var i16 int16
var i32 int32
var i64 int64
var i int
var ui8 uint8
var ui16 uint16
var ui32 uint32
var ui64 uint64
var ui uint
var pi8 *int8
var _i8 _int8
var _pi8 *_int8
simpleTests := []struct {
src pgtype.Int2
dst interface{}
expected interface{}
}{
{src: pgtype.Int2{Int: 42, Valid: true}, dst: &i8, expected: int8(42)},
{src: pgtype.Int2{Int: 42, Valid: true}, dst: &i16, expected: int16(42)},
{src: pgtype.Int2{Int: 42, Valid: true}, dst: &i32, expected: int32(42)},
{src: pgtype.Int2{Int: 42, Valid: true}, dst: &i64, expected: int64(42)},
{src: pgtype.Int2{Int: 42, Valid: true}, dst: &i, expected: int(42)},
{src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui8, expected: uint8(42)},
{src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui16, expected: uint16(42)},
{src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui32, expected: uint32(42)},
{src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui64, expected: uint64(42)},
{src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui, expected: uint(42)},
{src: pgtype.Int2{Int: 42, Valid: true}, dst: &_i8, expected: _int8(42)},
{src: pgtype.Int2{Int: 0}, dst: &pi8, expected: ((*int8)(nil))},
{src: pgtype.Int2{Int: 0}, dst: &_pi8, expected: ((*_int8)(nil))},
}
for i, tt := range simpleTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
pointerAllocTests := []struct {
src pgtype.Int2
dst interface{}
expected interface{}
}{
{src: pgtype.Int2{Int: 42, Valid: true}, dst: &pi8, expected: int8(42)},
{src: pgtype.Int2{Int: 42, Valid: true}, dst: &_pi8, expected: _int8(42)},
}
for i, tt := range pointerAllocTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
errorTests := []struct {
src pgtype.Int2
dst interface{}
}{
{src: pgtype.Int2{Int: 150, Valid: true}, dst: &i8},
{src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui8},
{src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui16},
{src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui32},
{src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui64},
{src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui},
{src: pgtype.Int2{Int: 0}, dst: &i16},
}
for i, tt := range errorTests {
err := tt.src.AssignTo(tt.dst)
if err == nil {
t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst)
}
}
}
+108 -93
View File
@@ -300,7 +300,7 @@ func NewConnInfo() *ConnInfo {
ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID})
ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID})
ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID})
ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}})
ci.RegisterDataType(DataType{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}})
ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID})
ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID})
ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID})
@@ -324,7 +324,7 @@ func NewConnInfo() *ConnInfo {
ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID})
ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID})
ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID})
ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID, Codec: Int2Codec{}})
ci.RegisterDataType(DataType{Name: "int2", OID: Int2OID, Codec: Int2Codec{}})
ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID})
ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID})
ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID})
@@ -398,20 +398,10 @@ func NewConnInfo() *ConnInfo {
return ci
}
func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) {
for name, oid := range nameOIDs {
var value Value
if t, ok := nameValues[name]; ok {
value = reflect.New(reflect.ValueOf(t).Elem().Type()).Interface().(Value)
} else {
value = &GenericText{}
}
ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid})
}
}
func (ci *ConnInfo) RegisterDataType(t DataType) {
t.Value = NewValue(t.Value)
if t.Value != nil {
t.Value = NewValue(t.Value)
}
ci.oidToDataType[t.OID] = &t
ci.nameToDataType[t.Name] = &t
@@ -463,8 +453,10 @@ func (ci *ConnInfo) buildReflectTypeToDataType() {
ci.reflectTypeToDataType = make(map[reflect.Type]*DataType)
for _, dt := range ci.oidToDataType {
if _, is := dt.Value.(TypeValue); !is {
ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt
if dt.Value != nil {
if _, is := dt.Value.(TypeValue); !is {
ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt
}
}
}
@@ -583,8 +575,14 @@ func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode
} else {
switch formatCode {
case BinaryFormatCode:
if dt.binaryDecoder == nil {
return fmt.Errorf("dt.binaryDecoder is nil")
}
err = dt.binaryDecoder.DecodeBinary(ci, src)
case TextFormatCode:
if dt.textDecoder == nil {
return fmt.Errorf("dt.textDecoder is nil")
}
err = dt.textDecoder.DecodeText(ci, src)
}
}
@@ -782,14 +780,105 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt
return newPlan.Scan(ci, oid, formatCode, src, dst)
}
type pointerPointerScanPlan struct {
dstType reflect.Type
next ScanPlan
}
func (plan *pointerPointerScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
if plan.dstType != reflect.TypeOf(dst) {
newPlan := ci.PlanScan(oid, formatCode, dst)
return newPlan.Scan(ci, oid, formatCode, src, dst)
}
el := reflect.ValueOf(dst).Elem()
if src == nil {
el.Set(reflect.Zero(el.Type()))
return nil
}
el.Set(reflect.New(el.Type().Elem()))
return plan.next.Scan(ci, oid, formatCode, src, el.Interface())
}
func tryPointerPointerScanPlan(dst interface{}) (plan *pointerPointerScanPlan, nextDst interface{}, ok bool) {
if dstValue := reflect.ValueOf(dst); dstValue.Kind() == reflect.Ptr {
elemValue := dstValue.Elem()
if elemValue.Kind() == reflect.Ptr {
plan = &pointerPointerScanPlan{dstType: dstValue.Type()}
return plan, reflect.Zero(elemValue.Type()).Interface(), true
}
}
return nil, nil, false
}
var elemKindToBasePointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{
reflect.Int: reflect.TypeOf(new(int)),
reflect.Int8: reflect.TypeOf(new(int8)),
reflect.Int16: reflect.TypeOf(new(int16)),
reflect.Int32: reflect.TypeOf(new(int32)),
reflect.Int64: reflect.TypeOf(new(int64)),
reflect.Uint: reflect.TypeOf(new(uint)),
reflect.Uint8: reflect.TypeOf(new(uint8)),
reflect.Uint16: reflect.TypeOf(new(uint16)),
reflect.Uint32: reflect.TypeOf(new(uint32)),
reflect.Uint64: reflect.TypeOf(new(uint64)),
reflect.Float32: reflect.TypeOf(new(float32)),
reflect.Float64: reflect.TypeOf(new(float64)),
reflect.String: reflect.TypeOf(new(string)),
}
type baseTypeScanPlan struct {
dstType reflect.Type
nextDstType reflect.Type
next ScanPlan
}
func (plan *baseTypeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
if plan.dstType != reflect.TypeOf(dst) {
newPlan := ci.PlanScan(oid, formatCode, dst)
return newPlan.Scan(ci, oid, formatCode, src, dst)
}
return plan.next.Scan(ci, oid, formatCode, src, reflect.ValueOf(dst).Convert(plan.nextDstType).Interface())
}
func tryBaseTypeScanPlan(dst interface{}) (plan *baseTypeScanPlan, nextDst interface{}, ok bool) {
dstValue := reflect.ValueOf(dst)
if dstValue.Kind() == reflect.Ptr {
elemValue := dstValue.Elem()
nextDstType := elemKindToBasePointerTypes[elemValue.Kind()]
if nextDstType != nil {
return &baseTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true
}
}
return nil, nil, false
}
// PlanScan prepares a plan to scan a value into dst.
func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan {
if oid != 0 {
if dt, ok := ci.DataTypeForOID(oid); ok && dt.Codec != nil {
plan := dt.Codec.PlanScan(ci, oid, formatCode, dst, false)
if plan != nil {
if plan := dt.Codec.PlanScan(ci, oid, formatCode, dst, false); plan != nil {
return plan
}
if pointerPointerPlan, nextDst, ok := tryPointerPointerScanPlan(dst); ok {
if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil {
pointerPointerPlan.next = nextPlan
return pointerPointerPlan
}
}
if baseTypePlan, nextDst, ok := tryBaseTypeScanPlan(dst); ok {
if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil {
baseTypePlan.next = nextPlan
return baseTypePlan
}
}
}
}
@@ -908,77 +997,3 @@ func NewValue(v Value) Value {
}
var ErrScanTargetTypeChanged = errors.New("scan target type changed")
var nameValues map[string]Value
func init() {
nameValues = map[string]Value{
"_aclitem": &ACLItemArray{},
"_bool": &BoolArray{},
"_bpchar": &BPCharArray{},
"_bytea": &ByteaArray{},
"_cidr": &CIDRArray{},
"_date": &DateArray{},
"_float4": &Float4Array{},
"_float8": &Float8Array{},
"_inet": &InetArray{},
"_int2": &Int2Array{},
"_int4": &Int4Array{},
"_int8": &Int8Array{},
"_numeric": &NumericArray{},
"_text": &TextArray{},
"_timestamp": &TimestampArray{},
"_timestamptz": &TimestamptzArray{},
"_uuid": &UUIDArray{},
"_varchar": &VarcharArray{},
"_jsonb": &JSONBArray{},
"aclitem": &ACLItem{},
"bit": &Bit{},
"bool": &Bool{},
"box": &Box{},
"bpchar": &BPChar{},
"bytea": &Bytea{},
"char": &QChar{},
"cid": &CID{},
"cidr": &CIDR{},
"circle": &Circle{},
"date": &Date{},
"daterange": &Daterange{},
"float4": &Float4{},
"float8": &Float8{},
"hstore": &Hstore{},
"inet": &Inet{},
"int2": &Int2{},
"int4": &Int4{},
"int4range": &Int4range{},
"int8": &Int8{},
"int8range": &Int8range{},
"interval": &Interval{},
"json": &JSON{},
"jsonb": &JSONB{},
"line": &Line{},
"lseg": &Lseg{},
"macaddr": &Macaddr{},
"name": &Name{},
"numeric": &Numeric{},
"numrange": &Numrange{},
"oid": &OIDValue{},
"path": &Path{},
"point": &Point{},
"polygon": &Polygon{},
"record": &Record{},
"text": &Text{},
"tid": &TID{},
"timestamp": &Timestamp{},
"timestamptz": &Timestamptz{},
"tsrange": &Tsrange{},
"_tsrange": &TsrangeArray{},
"tstzrange": &Tstzrange{},
"_tstzrange": &TstzrangeArray{},
"unknown": &Unknown{},
"uuid": &UUID{},
"varbit": &Varbit{},
"varchar": &Varchar{},
"xid": &XID{},
}
}
-1
View File
@@ -1,4 +1,3 @@
erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go
erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go
erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go
erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go
-35
View File
@@ -1,35 +0,0 @@
package pgtype
import "fmt"
func (Int2) BinaryFormatSupported() bool {
return true
}
func (Int2) TextFormatSupported() bool {
return true
}
func (Int2) PreferredFormat() int16 {
return BinaryFormatCode
}
func (dst *Int2) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error {
switch format {
case BinaryFormatCode:
return dst.DecodeBinary(ci, src)
case TextFormatCode:
return dst.DecodeText(ci, src)
}
return fmt.Errorf("unknown format code %d", format)
}
func (src Int2) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) {
switch format {
case BinaryFormatCode:
return src.EncodeBinary(ci, buf)
case TextFormatCode:
return src.EncodeText(ci, buf)
}
return nil, fmt.Errorf("unknown format code %d", format)
}
+36 -37
View File
@@ -920,65 +920,64 @@ func TestQueryRowCoreIntegerDecoding(t *testing.T) {
}
failedDecodeTests := []struct {
sql string
scanArg interface{}
expectedErr string
sql string
scanArg interface{}
}{
// Check any integer type where value is outside Go:int8 range cannot be decoded
{"select 128::int2", &actual.i8, "is greater than"},
{"select 128::int4", &actual.i8, "is greater than"},
{"select 128::int8", &actual.i8, "is greater than"},
{"select -129::int2", &actual.i8, "is less than"},
{"select -129::int4", &actual.i8, "is less than"},
{"select -129::int8", &actual.i8, "is less than"},
{"select 128::int2", &actual.i8},
{"select 128::int4", &actual.i8},
{"select 128::int8", &actual.i8},
{"select -129::int2", &actual.i8},
{"select -129::int4", &actual.i8},
{"select -129::int8", &actual.i8},
// Check any integer type where value is outside Go:int16 range cannot be decoded
{"select 32768::int4", &actual.i16, "is greater than"},
{"select 32768::int8", &actual.i16, "is greater than"},
{"select -32769::int4", &actual.i16, "is less than"},
{"select -32769::int8", &actual.i16, "is less than"},
{"select 32768::int4", &actual.i16},
{"select 32768::int8", &actual.i16},
{"select -32769::int4", &actual.i16},
{"select -32769::int8", &actual.i16},
// Check any integer type where value is outside Go:int32 range cannot be decoded
{"select 2147483648::int8", &actual.i32, "is greater than"},
{"select -2147483649::int8", &actual.i32, "is less than"},
{"select 2147483648::int8", &actual.i32},
{"select -2147483649::int8", &actual.i32},
// Check any integer type where value is outside Go:uint range cannot be decoded
{"select -1::int2", &actual.ui, "is less than"},
{"select -1::int4", &actual.ui, "is less than"},
{"select -1::int8", &actual.ui, "is less than"},
{"select -1::int2", &actual.ui},
{"select -1::int4", &actual.ui},
{"select -1::int8", &actual.ui},
// Check any integer type where value is outside Go:uint8 range cannot be decoded
{"select 256::int2", &actual.ui8, "is greater than"},
{"select 256::int4", &actual.ui8, "is greater than"},
{"select 256::int8", &actual.ui8, "is greater than"},
{"select -1::int2", &actual.ui8, "is less than"},
{"select -1::int4", &actual.ui8, "is less than"},
{"select -1::int8", &actual.ui8, "is less than"},
{"select 256::int2", &actual.ui8},
{"select 256::int4", &actual.ui8},
{"select 256::int8", &actual.ui8},
{"select -1::int2", &actual.ui8},
{"select -1::int4", &actual.ui8},
{"select -1::int8", &actual.ui8},
// Check any integer type where value is outside Go:uint16 cannot be decoded
{"select 65536::int4", &actual.ui16, "is greater than"},
{"select 65536::int8", &actual.ui16, "is greater than"},
{"select -1::int2", &actual.ui16, "is less than"},
{"select -1::int4", &actual.ui16, "is less than"},
{"select -1::int8", &actual.ui16, "is less than"},
{"select 65536::int4", &actual.ui16},
{"select 65536::int8", &actual.ui16},
{"select -1::int2", &actual.ui16},
{"select -1::int4", &actual.ui16},
{"select -1::int8", &actual.ui16},
// Check any integer type where value is outside Go:uint32 range cannot be decoded
{"select 4294967296::int8", &actual.ui32, "is greater than"},
{"select -1::int2", &actual.ui32, "is less than"},
{"select -1::int4", &actual.ui32, "is less than"},
{"select -1::int8", &actual.ui32, "is less than"},
{"select 4294967296::int8", &actual.ui32},
{"select -1::int2", &actual.ui32},
{"select -1::int4", &actual.ui32},
{"select -1::int8", &actual.ui32},
// Check any integer type where value is outside Go:uint64 range cannot be decoded
{"select -1::int2", &actual.ui64, "is less than"},
{"select -1::int4", &actual.ui64, "is less than"},
{"select -1::int8", &actual.ui64, "is less than"},
{"select -1::int2", &actual.ui64},
{"select -1::int4", &actual.ui64},
{"select -1::int8", &actual.ui64},
}
for i, tt := range failedDecodeTests {
err := conn.QueryRow(context.Background(), tt.sql).Scan(tt.scanArg)
if err == nil {
t.Errorf("%d. Expected failure to decode, but unexpectedly succeeded: %v (sql -> %v)", i, err, tt.sql)
} else if !strings.Contains(err.Error(), tt.expectedErr) {
} else if !strings.Contains(err.Error(), "can't scan") {
t.Errorf("%d. Expected failure to decode, but got: %v (sql -> %v)", i, err, tt.sql)
}
+29 -20
View File
@@ -246,31 +246,40 @@ func (rows *connRows) Values() ([]interface{}, error) {
}
if dt, ok := rows.connInfo.DataTypeForOID(fd.DataTypeOID); ok {
value := dt.Value
if dt.Value != nil {
switch fd.Format {
case TextFormatCode:
decoder, ok := value.(pgtype.TextDecoder)
if !ok {
decoder = &pgtype.GenericText{}
value := dt.Value
switch fd.Format {
case TextFormatCode:
decoder, ok := value.(pgtype.TextDecoder)
if !ok {
decoder = &pgtype.GenericText{}
}
err := decoder.DecodeText(rows.connInfo, buf)
if err != nil {
rows.fatal(err)
}
values = append(values, decoder.(pgtype.Value).Get())
case BinaryFormatCode:
decoder, ok := value.(pgtype.BinaryDecoder)
if !ok {
decoder = &pgtype.GenericBinary{}
}
err := decoder.DecodeBinary(rows.connInfo, buf)
if err != nil {
rows.fatal(err)
}
values = append(values, value.Get())
default:
rows.fatal(errors.New("Unknown format code"))
}
err := decoder.DecodeText(rows.connInfo, buf)
} else if dt.Codec != nil {
value, err := dt.Codec.DecodeValue(rows.connInfo, fd.DataTypeOID, fd.Format, buf)
if err != nil {
rows.fatal(err)
}
values = append(values, decoder.(pgtype.Value).Get())
case BinaryFormatCode:
decoder, ok := value.(pgtype.BinaryDecoder)
if !ok {
decoder = &pgtype.GenericBinary{}
}
err := decoder.DecodeBinary(rows.connInfo, buf)
if err != nil {
rows.fatal(err)
}
values = append(values, value.Get())
default:
rows.fatal(errors.New("Unknown format code"))
values = append(values, value)
}
} else {
switch fd.Format {
+59 -34
View File
@@ -115,19 +115,30 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e
}
if dt, found := ci.DataTypeForValue(arg); found {
v := dt.Value
err := v.Set(arg)
if err != nil {
return nil, err
if dt.Value != nil {
v := dt.Value
err := v.Set(arg)
if err != nil {
return nil, err
}
buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil)
if err != nil {
return nil, err
}
if buf == nil {
return nil, nil
}
return string(buf), nil
} else if dt.Codec != nil {
buf, err := dt.Codec.Encode(ci, 0, TextFormatCode, arg, nil)
if err != nil {
return nil, err
}
if buf == nil {
return nil, nil
}
return string(buf), nil
}
buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil)
if err != nil {
return nil, err
}
if buf == nil {
return nil, nil
}
return string(buf), nil
}
if refVal.Kind() == reflect.Ptr {
@@ -188,33 +199,47 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32
}
if dt, ok := ci.DataTypeForOID(oid); ok {
value := dt.Value
err := value.Set(arg)
if err != nil {
{
if arg, ok := arg.(driver.Valuer); ok {
v, err := callValuerValue(arg)
if err != nil {
return nil, err
if dt.Value != nil {
value := dt.Value
err := value.Set(arg)
if err != nil {
{
if arg, ok := arg.(driver.Valuer); ok {
v, err := callValuerValue(arg)
if err != nil {
return nil, err
}
return encodePreparedStatementArgument(ci, buf, oid, v)
}
return encodePreparedStatementArgument(ci, buf, oid, v)
}
return nil, err
}
return nil, err
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf)
if err != nil {
return nil, err
}
if argBuf != nil {
buf = argBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
return buf, nil
} else if dt.Codec != nil {
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
argBuf, err := dt.Codec.Encode(ci, oid, BinaryFormatCode, arg, buf)
if err != nil {
return nil, err
}
if argBuf != nil {
buf = argBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
return buf, nil
}
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf)
if err != nil {
return nil, err
}
if argBuf != nil {
buf = argBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
return buf, nil
}
if strippedArg, ok := stripNamedType(&refVal); ok {