From 449a8a4f8e7a35dc38380391065a4f7122f9d21b Mon Sep 17 00:00:00 2001 From: Simo Haasanen Date: Fri, 7 Aug 2020 13:10:32 +0100 Subject: [PATCH] Add multidimensional array and slice support. Adds array support - previously only slices were supported. Adds new test cases for multidimensional arrays and slices. All previous test cases are unmodified and passed (fully backwards compatible). Removes hard-coded type conversions for arrays, instead now relies on the type support of the array element's type conversion support. Less maintenance for arrays, new type conversions are automatically supported when array's element gains new type support. Simplifies typed_array_gen.sh generator script by removing the hard-coded single-dimensional types for arrays. Only typed_array.go.erb and typed_array_gen.sh have been changed + 1 new auxiliary function in array.go file + additional tests in test files for each array. Other changes are from generated code. --- aclitem_array.go | 212 ++++++++----- aclitem_array_test.go | 171 +++++++++++ array.go | 22 ++ bool_array.go | 212 ++++++++----- bool_array_test.go | 125 ++++++++ bpchar_array.go | 212 ++++++++----- bytea_array.go | 184 +++++++++--- bytea_array_test.go | 104 +++++++ cidr_array.go | 241 ++++++++------- cidr_array_test.go | 144 +++++++++ date_array.go | 213 +++++++++----- date_array_test.go | 179 +++++++++++ enum_array.go | 212 ++++++++----- enum_array_test.go | 125 ++++++++ float4_array.go | 212 ++++++++----- float4_array_test.go | 125 ++++++++ float8_array.go | 212 ++++++++----- float8_array_test.go | 101 +++++++ hstore_array.go | 184 +++++++++--- hstore_array_test.go | 250 +++++++++++++++- inet_array.go | 241 ++++++++------- inet_array_test.go | 144 +++++++++ int2_array.go | 604 +++++++++----------------------------- int2_array_test.go | 125 ++++++++ int4_array.go | 604 +++++++++----------------------------- int4_array_test.go | 125 ++++++++ int8_array.go | 604 +++++++++----------------------------- int8_array_test.go | 125 ++++++++ jsonb_array.go | 184 +++++++++--- macaddr_array.go | 213 +++++++++----- macaddr_array_test.go | 152 ++++++++++ numeric_array.go | 380 +++++++++--------------- numeric_array_test.go | 125 ++++++++ text_array.go | 212 ++++++++----- text_array_test.go | 125 ++++++++ timestamp_array.go | 213 +++++++++----- timestamp_array_test.go | 143 +++++++++ timestamptz_array.go | 213 +++++++++----- timestamptz_array_test.go | 179 +++++++++++ tstzrange_array.go | 165 +++++++++-- typed_array.go.erb | 187 ++++++++---- typed_array_gen.sh | 46 +-- uuid_array.go | 268 +++++++++-------- uuid_array_test.go | 152 ++++++++++ varchar_array.go | 212 ++++++++----- varchar_array_test.go | 125 ++++++++ 46 files changed, 6193 insertions(+), 3113 deletions(-) diff --git a/aclitem_array.go b/aclitem_array.go index 2df0ccd4..09a64fb6 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -4,6 +4,7 @@ package pgtype import ( "database/sql/driver" + "reflect" errors "golang.org/x/xerrors" ) @@ -28,68 +29,94 @@ func (dst *ACLItemArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = ACLItemArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = ACLItemArray{Status: Null} - } else if len(value) == 0 { - *dst = ACLItemArray{Status: Present} - } else { - elements := make([]ACLItem, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = ACLItemArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = ACLItemArray{Status: Null} - } else if len(value) == 0 { - *dst = ACLItemArray{Status: Present} - } else { - elements := make([]ACLItem, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = ACLItemArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []ACLItem: - if value == nil { - *dst = ACLItemArray{Status: Null} - } else if len(value) == 0 { - *dst = ACLItemArray{Status: Present} - } else { - *dst = ACLItemArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for ACLItemArray", src) + } + if elementsLength == 0 { + *dst = ACLItemArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to ACLItemArray", value) + return errors.Errorf("cannot convert %v to ACLItemArray", src) + } + + *dst = ACLItemArray{ + Elements: make([]ACLItem, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]ACLItem, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to ACLItemArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *ACLItemArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to ACLItemArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in ACLItemArray", err) + } + index++ + + return index, nil +} + func (dst ACLItemArray) Get() interface{} { switch dst.Status { case Present: @@ -104,32 +131,26 @@ func (dst ACLItemArray) Get() interface{} { func (src *ACLItemArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -137,6 +158,49 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *ACLItemArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from ACLItemArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = ACLItemArray{Status: Null} diff --git a/aclitem_array_test.go b/aclitem_array_test.go index fb1e93fc..73e9ce71 100644 --- a/aclitem_array_test.go +++ b/aclitem_array_test.go @@ -69,6 +69,74 @@ func TestACLItemArraySet(t *testing.T) { source: (([]string)(nil)), result: pgtype.ACLItemArray{Status: pgtype.Null}, }, + { + source: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -88,6 +156,10 @@ func TestACLItemArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string simpleTests := []struct { src pgtype.ACLItemArray @@ -117,6 +189,78 @@ func TestACLItemArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + }, } for i, tt := range simpleTests { @@ -142,6 +286,33 @@ func TestACLItemArrayAssignTo(t *testing.T) { }, dst: &stringSlice, }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, } for i, tt := range errorTests { diff --git a/array.go b/array.go index bd3a993b..b779cd9d 100644 --- a/array.go +++ b/array.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "io" + "reflect" "strconv" "strings" "unicode" @@ -350,3 +351,24 @@ func QuoteArrayElementIfNeeded(src string) string { } return src } + +func findDimensionsFromValue(value reflect.Value, dimensions []ArrayDimension, elementsLength int) ([]ArrayDimension, int, bool) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + length := value.Len() + if 0 == elementsLength { + elementsLength = length + } else { + elementsLength *= length + } + dimensions = append(dimensions, ArrayDimension{Length: int32(length), LowerBound: 1}) + for i := 0; i < length; i++ { + if d, l, ok := findDimensionsFromValue(value.Index(i), dimensions, elementsLength); ok { + return d, l, true + } + } + } + return dimensions, elementsLength, true +} diff --git a/bool_array.go b/bool_array.go index a8c75a25..6569d5ca 100644 --- a/bool_array.go +++ b/bool_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *BoolArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = BoolArray{Status: Null} + return nil + } - case []bool: - if value == nil { - *dst = BoolArray{Status: Null} - } else if len(value) == 0 { - *dst = BoolArray{Status: Present} - } else { - elements := make([]Bool, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BoolArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*bool: - if value == nil { - *dst = BoolArray{Status: Null} - } else if len(value) == 0 { - *dst = BoolArray{Status: Present} - } else { - elements := make([]Bool, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BoolArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Bool: - if value == nil { - *dst = BoolArray{Status: Null} - } else if len(value) == 0 { - *dst = BoolArray{Status: Present} - } else { - *dst = BoolArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for BoolArray", src) + } + if elementsLength == 0 { + *dst = BoolArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to BoolArray", value) + return errors.Errorf("cannot convert %v to BoolArray", src) + } + + *dst = BoolArray{ + Elements: make([]Bool, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Bool, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to BoolArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *BoolArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to BoolArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in BoolArray", err) + } + index++ + + return index, nil +} + func (dst BoolArray) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst BoolArray) Get() interface{} { func (src *BoolArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]bool: - *v = make([]bool, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*bool: - *v = make([]*bool, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from BoolArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = BoolArray{Status: Null} diff --git a/bool_array_test.go b/bool_array_test.go index bef94622..7f31e252 100644 --- a/bool_array_test.go +++ b/bool_array_test.go @@ -68,6 +68,54 @@ func TestBoolArraySet(t *testing.T) { source: (([]bool)(nil)), result: pgtype.BoolArray{Status: pgtype.Null}, }, + { + source: [][]bool{{true}, {false}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]bool{{true}, {false}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -87,6 +135,10 @@ func TestBoolArrayAssignTo(t *testing.T) { var boolSlice []bool type _boolSlice []bool var namedBoolSlice _boolSlice + var boolSliceDim2 [][]bool + var boolSliceDim4 [][][][]bool + var boolArrayDim2 [2][1]bool + var boolArrayDim4 [2][1][1][3]bool simpleTests := []struct { src pgtype.BoolArray @@ -116,6 +168,58 @@ func TestBoolArrayAssignTo(t *testing.T) { dst: &boolSlice, expected: (([]bool)(nil)), }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]bool{{true}, {false}}, + dst: &boolSliceDim2, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + dst: &boolSliceDim4, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]bool{{true}, {false}}, + dst: &boolArrayDim2, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + dst: &boolArrayDim4, + }, } for i, tt := range simpleTests { @@ -141,6 +245,27 @@ func TestBoolArrayAssignTo(t *testing.T) { }, dst: &boolSlice, }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &boolArrayDim2, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &boolSlice, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &boolArrayDim4, + }, } for i, tt := range errorTests { diff --git a/bpchar_array.go b/bpchar_array.go index ed6fe703..8aef8330 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *BPCharArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = BPCharArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = BPCharArray{Status: Null} - } else if len(value) == 0 { - *dst = BPCharArray{Status: Present} - } else { - elements := make([]BPChar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BPCharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = BPCharArray{Status: Null} - } else if len(value) == 0 { - *dst = BPCharArray{Status: Present} - } else { - elements := make([]BPChar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BPCharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []BPChar: - if value == nil { - *dst = BPCharArray{Status: Null} - } else if len(value) == 0 { - *dst = BPCharArray{Status: Present} - } else { - *dst = BPCharArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for BPCharArray", src) + } + if elementsLength == 0 { + *dst = BPCharArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to BPCharArray", value) + return errors.Errorf("cannot convert %v to BPCharArray", src) + } + + *dst = BPCharArray{ + Elements: make([]BPChar, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]BPChar, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to BPCharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *BPCharArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to BPCharArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in BPCharArray", err) + } + index++ + + return index, nil +} + func (dst BPCharArray) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst BPCharArray) Get() interface{} { func (src *BPCharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *BPCharArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *BPCharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from BPCharArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *BPCharArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = BPCharArray{Status: Null} diff --git a/bytea_array.go b/bytea_array.go index 87d77f9e..3addb99a 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,49 +31,94 @@ func (dst *ByteaArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = ByteaArray{Status: Null} + return nil + } - case [][]byte: - if value == nil { - *dst = ByteaArray{Status: Null} - } else if len(value) == 0 { - *dst = ByteaArray{Status: Present} - } else { - elements := make([]Bytea, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = ByteaArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Bytea: - if value == nil { - *dst = ByteaArray{Status: Null} - } else if len(value) == 0 { - *dst = ByteaArray{Status: Present} - } else { - *dst = ByteaArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for ByteaArray", src) + } + if elementsLength == 0 { + *dst = ByteaArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to ByteaArray", value) + return errors.Errorf("cannot convert %v to ByteaArray", src) + } + + *dst = ByteaArray{ + Elements: make([]Bytea, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Bytea, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to ByteaArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *ByteaArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to ByteaArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in ByteaArray", err) + } + index++ + + return index, nil +} + func (dst ByteaArray) Get() interface{} { switch dst.Status { case Present: @@ -87,23 +133,26 @@ func (dst ByteaArray) Get() interface{} { func (src *ByteaArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[][]byte: - *v = make([][]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -111,6 +160,49 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *ByteaArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from ByteaArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = ByteaArray{Status: Null} diff --git a/bytea_array_test.go b/bytea_array_test.go index a4eb2d91..f40005a2 100644 --- a/bytea_array_test.go +++ b/bytea_array_test.go @@ -68,6 +68,54 @@ func TestByteaArraySet(t *testing.T) { source: (([][]byte)(nil)), result: pgtype.ByteaArray{Status: pgtype.Null}, }, + { + source: [][][]byte{{{1}}, {{2}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1][]byte{{{1}}, {{2}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -85,6 +133,10 @@ func TestByteaArraySet(t *testing.T) { func TestByteaArrayAssignTo(t *testing.T) { var byteByteSlice [][]byte + var byteByteSliceDim2 [][][]byte + var byteByteSliceDim4 [][][][][]byte + var byteByteArraySliceDim2 [2][1][]byte + var byteByteArraySliceDim4 [2][1][1][3][]byte simpleTests := []struct { src pgtype.ByteaArray @@ -105,6 +157,58 @@ func TestByteaArrayAssignTo(t *testing.T) { dst: &byteByteSlice, expected: (([][]byte)(nil)), }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteByteSliceDim2, + expected: [][][]byte{{{1}}, {{2}}}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &byteByteSliceDim4, + expected: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteByteArraySliceDim2, + expected: [2][1][]byte{{{1}}, {{2}}}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &byteByteArraySliceDim4, + expected: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + }, } for i, tt := range simpleTests { diff --git a/cidr_array.go b/cidr_array.go index a2e025cc..1ef2f428 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "net" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,87 +31,94 @@ func (dst *CIDRArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = CIDRArray{Status: Null} + return nil + } - case []*net.IPNet: - if value == nil { - *dst = CIDRArray{Status: Null} - } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []net.IP: - if value == nil { - *dst = CIDRArray{Status: Null} - } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*net.IP: - if value == nil { - *dst = CIDRArray{Status: Null} - } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []CIDR: - if value == nil { - *dst = CIDRArray{Status: Null} - } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} - } else { - *dst = CIDRArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for CIDRArray", src) + } + if elementsLength == 0 { + *dst = CIDRArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to CIDRArray", value) + return errors.Errorf("cannot convert %v to CIDRArray", src) + } + + *dst = CIDRArray{ + Elements: make([]CIDR, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]CIDR, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to CIDRArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *CIDRArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to CIDRArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in CIDRArray", err) + } + index++ + + return index, nil +} + func (dst CIDRArray) Get() interface{} { switch dst.Status { case Present: @@ -126,41 +133,26 @@ func (dst CIDRArray) Get() interface{} { func (src *CIDRArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]*net.IPNet: - *v = make([]*net.IPNet, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]net.IP: - *v = make([]net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.IP: - *v = make([]*net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -168,6 +160,49 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from CIDRArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = CIDRArray{Status: Null} diff --git a/cidr_array_test.go b/cidr_array_test.go index 421aec4e..b1769c38 100644 --- a/cidr_array_test.go +++ b/cidr_array_test.go @@ -80,6 +80,74 @@ func TestCIDRArraySet(t *testing.T) { source: (([]net.IP)(nil)), result: pgtype.CIDRArray{Status: pgtype.Null}, }, + { + source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -98,6 +166,10 @@ func TestCIDRArraySet(t *testing.T) { func TestCIDRArrayAssignTo(t *testing.T) { var ipnetSlice []*net.IPNet var ipSlice []net.IP + var ipSliceDim2 [][]net.IP + var ipnetSliceDim4 [][][][]*net.IPNet + var ipArrayDim2 [2][1]net.IP + var ipnetArrayDim4 [2][1][1][3]*net.IPNet simpleTests := []struct { src pgtype.CIDRArray @@ -150,6 +222,78 @@ func TestCIDRArrayAssignTo(t *testing.T) { dst: &ipSlice, expected: (([]net.IP)(nil)), }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipSliceDim2, + expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetSliceDim4, + expected: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipArrayDim2, + expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetArrayDim4, + expected: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, } for i, tt := range simpleTests { diff --git a/date_array.go b/date_array.go index fe185f67..4ccdafe0 100644 --- a/date_array.go +++ b/date_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "time" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,68 +31,94 @@ func (dst *DateArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = DateArray{Status: Null} + return nil + } - case []time.Time: - if value == nil { - *dst = DateArray{Status: Null} - } else if len(value) == 0 { - *dst = DateArray{Status: Present} - } else { - elements := make([]Date, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = DateArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*time.Time: - if value == nil { - *dst = DateArray{Status: Null} - } else if len(value) == 0 { - *dst = DateArray{Status: Present} - } else { - elements := make([]Date, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = DateArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Date: - if value == nil { - *dst = DateArray{Status: Null} - } else if len(value) == 0 { - *dst = DateArray{Status: Present} - } else { - *dst = DateArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for DateArray", src) + } + if elementsLength == 0 { + *dst = DateArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to DateArray", value) + return errors.Errorf("cannot convert %v to DateArray", src) + } + + *dst = DateArray{ + Elements: make([]Date, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Date, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to DateArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *DateArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to DateArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in DateArray", err) + } + index++ + + return index, nil +} + func (dst DateArray) Get() interface{} { switch dst.Status { case Present: @@ -107,32 +133,26 @@ func (dst DateArray) Get() interface{} { func (src *DateArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*time.Time: - *v = make([]*time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -140,6 +160,49 @@ func (src *DateArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *DateArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from DateArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = DateArray{Status: Null} diff --git a/date_array_test.go b/date_array_test.go index 9f4a96a9..089c7dd4 100644 --- a/date_array_test.go +++ b/date_array_test.go @@ -69,6 +69,78 @@ func TestDateArraySet(t *testing.T) { source: (([]time.Time)(nil)), result: pgtype.DateArray{Status: pgtype.Null}, }, + { + source: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -86,6 +158,10 @@ func TestDateArraySet(t *testing.T) { func TestDateArrayAssignTo(t *testing.T) { var timeSlice []time.Time + var timeSliceDim2 [][]time.Time + var timeSliceDim4 [][][][]time.Time + var timeArrayDim2 [2][1]time.Time + var timeArrayDim4 [2][1][1][3]time.Time simpleTests := []struct { src pgtype.DateArray @@ -106,6 +182,82 @@ func TestDateArrayAssignTo(t *testing.T) { dst: &timeSlice, expected: (([]time.Time)(nil)), }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeSliceDim2, + expected: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeSliceDim4, + expected: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + expected: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + expected: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, } for i, tt := range simpleTests { @@ -131,6 +283,33 @@ func TestDateArrayAssignTo(t *testing.T) { }, dst: &timeSlice, }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeSlice, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + }, } for i, tt := range errorTests { diff --git a/enum_array.go b/enum_array.go index 9312264c..2c83db24 100644 --- a/enum_array.go +++ b/enum_array.go @@ -4,6 +4,7 @@ package pgtype import ( "database/sql/driver" + "reflect" errors "golang.org/x/xerrors" ) @@ -28,68 +29,94 @@ func (dst *EnumArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = EnumArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = EnumArray{Status: Null} - } else if len(value) == 0 { - *dst = EnumArray{Status: Present} - } else { - elements := make([]GenericText, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = EnumArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = EnumArray{Status: Null} - } else if len(value) == 0 { - *dst = EnumArray{Status: Present} - } else { - elements := make([]GenericText, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = EnumArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []GenericText: - if value == nil { - *dst = EnumArray{Status: Null} - } else if len(value) == 0 { - *dst = EnumArray{Status: Present} - } else { - *dst = EnumArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for EnumArray", src) + } + if elementsLength == 0 { + *dst = EnumArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to EnumArray", value) + return errors.Errorf("cannot convert %v to EnumArray", src) + } + + *dst = EnumArray{ + Elements: make([]GenericText, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]GenericText, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to EnumArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *EnumArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to EnumArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in EnumArray", err) + } + index++ + + return index, nil +} + func (dst EnumArray) Get() interface{} { switch dst.Status { case Present: @@ -104,32 +131,26 @@ func (dst EnumArray) Get() interface{} { func (src *EnumArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -137,6 +158,49 @@ func (src *EnumArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *EnumArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from EnumArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = EnumArray{Status: Null} diff --git a/enum_array_test.go b/enum_array_test.go index 406c6b47..91a81ab6 100644 --- a/enum_array_test.go +++ b/enum_array_test.go @@ -67,6 +67,54 @@ func TestEnumArrayArraySet(t *testing.T) { source: (([]string)(nil)), result: pgtype.EnumArray{Status: pgtype.Null}, }, + { + source: [][]string{{"foo"}, {"bar"}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"foo"}, {"bar"}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -86,6 +134,10 @@ func TestEnumArrayArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string simpleTests := []struct { src pgtype.EnumArray @@ -115,6 +167,58 @@ func TestEnumArrayArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, } for i, tt := range simpleTests { @@ -140,6 +244,27 @@ func TestEnumArrayArrayAssignTo(t *testing.T) { }, dst: &stringSlice, }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, } for i, tt := range errorTests { diff --git a/float4_array.go b/float4_array.go index 0e95c446..78d1a860 100644 --- a/float4_array.go +++ b/float4_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *Float4Array) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = Float4Array{Status: Null} + return nil + } - case []float32: - if value == nil { - *dst = Float4Array{Status: Null} - } else if len(value) == 0 { - *dst = Float4Array{Status: Present} - } else { - elements := make([]Float4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*float32: - if value == nil { - *dst = Float4Array{Status: Null} - } else if len(value) == 0 { - *dst = Float4Array{Status: Present} - } else { - elements := make([]Float4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Float4: - if value == nil { - *dst = Float4Array{Status: Null} - } else if len(value) == 0 { - *dst = Float4Array{Status: Present} - } else { - *dst = Float4Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Float4Array", src) + } + if elementsLength == 0 { + *dst = Float4Array{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Float4Array", value) + return errors.Errorf("cannot convert %v to Float4Array", src) + } + + *dst = Float4Array{ + Elements: make([]Float4, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Float4, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Float4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *Float4Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to Float4Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in Float4Array", err) + } + index++ + + return index, nil +} + func (dst Float4Array) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst Float4Array) Get() interface{} { func (src *Float4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]float32: - *v = make([]float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float32: - *v = make([]*float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *Float4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from Float4Array") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4Array{Status: Null} diff --git a/float4_array_test.go b/float4_array_test.go index 658b3381..23a94ee8 100644 --- a/float4_array_test.go +++ b/float4_array_test.go @@ -68,6 +68,54 @@ func TestFloat4ArraySet(t *testing.T) { source: (([]float32)(nil)), result: pgtype.Float4Array{Status: pgtype.Null}, }, + { + source: [][]float32{{1}, {2}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]float32{{1}, {2}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -86,6 +134,10 @@ func TestFloat4ArraySet(t *testing.T) { func TestFloat4ArrayAssignTo(t *testing.T) { var float32Slice []float32 var namedFloat32Slice _float32Slice + var float32SliceDim2 [][]float32 + var float32SliceDim4 [][][][]float32 + var float32ArrayDim2 [2][1]float32 + var float32ArrayDim4 [2][1][1][3]float32 simpleTests := []struct { src pgtype.Float4Array @@ -115,6 +167,58 @@ func TestFloat4ArrayAssignTo(t *testing.T) { dst: &float32Slice, expected: (([]float32)(nil)), }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]float32{{1}, {2}}, + dst: &float32SliceDim2, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float32SliceDim4, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]float32{{1}, {2}}, + dst: &float32ArrayDim2, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float32ArrayDim4, + }, } for i, tt := range simpleTests { @@ -140,6 +244,27 @@ func TestFloat4ArrayAssignTo(t *testing.T) { }, dst: &float32Slice, }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32ArrayDim2, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32Slice, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/float8_array.go b/float8_array.go index 240e88d6..19223c52 100644 --- a/float8_array.go +++ b/float8_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *Float8Array) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = Float8Array{Status: Null} + return nil + } - case []float64: - if value == nil { - *dst = Float8Array{Status: Null} - } else if len(value) == 0 { - *dst = Float8Array{Status: Present} - } else { - elements := make([]Float8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*float64: - if value == nil { - *dst = Float8Array{Status: Null} - } else if len(value) == 0 { - *dst = Float8Array{Status: Present} - } else { - elements := make([]Float8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Float8: - if value == nil { - *dst = Float8Array{Status: Null} - } else if len(value) == 0 { - *dst = Float8Array{Status: Present} - } else { - *dst = Float8Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Float8Array", src) + } + if elementsLength == 0 { + *dst = Float8Array{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Float8Array", value) + return errors.Errorf("cannot convert %v to Float8Array", src) + } + + *dst = Float8Array{ + Elements: make([]Float8, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Float8, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Float8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *Float8Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to Float8Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in Float8Array", err) + } + index++ + + return index, nil +} + func (dst Float8Array) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst Float8Array) Get() interface{} { func (src *Float8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]float64: - *v = make([]float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float64: - *v = make([]*float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *Float8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from Float8Array") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8Array{Status: Null} diff --git a/float8_array_test.go b/float8_array_test.go index 2e29a19f..052ab3f3 100644 --- a/float8_array_test.go +++ b/float8_array_test.go @@ -68,6 +68,30 @@ func TestFloat8ArraySet(t *testing.T) { source: (([]float64)(nil)), result: pgtype.Float8Array{Status: pgtype.Null}, }, + { + source: [][]float64{{1}, {2}}, + result: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -86,6 +110,10 @@ func TestFloat8ArraySet(t *testing.T) { func TestFloat8ArrayAssignTo(t *testing.T) { var float64Slice []float64 var namedFloat64Slice _float64Slice + var float64SliceDim2 [][]float64 + var float64SliceDim4 [][][][]float64 + var float64ArrayDim2 [2][1]float64 + var float64ArrayDim4 [2][1][1][3]float64 simpleTests := []struct { src pgtype.Float8Array @@ -115,6 +143,58 @@ func TestFloat8ArrayAssignTo(t *testing.T) { dst: &float64Slice, expected: (([]float64)(nil)), }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]float64{{1}, {2}}, + dst: &float64SliceDim2, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float64SliceDim4, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]float64{{1}, {2}}, + dst: &float64ArrayDim2, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float64ArrayDim4, + }, } for i, tt := range simpleTests { @@ -140,6 +220,27 @@ func TestFloat8ArrayAssignTo(t *testing.T) { }, dst: &float64Slice, }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float64ArrayDim2, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float64Slice, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float64ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/hstore_array.go b/hstore_array.go index b258cbdd..8764aae7 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,49 +31,94 @@ func (dst *HstoreArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = HstoreArray{Status: Null} + return nil + } - case []map[string]string: - if value == nil { - *dst = HstoreArray{Status: Null} - } else if len(value) == 0 { - *dst = HstoreArray{Status: Present} - } else { - elements := make([]Hstore, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = HstoreArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Hstore: - if value == nil { - *dst = HstoreArray{Status: Null} - } else if len(value) == 0 { - *dst = HstoreArray{Status: Present} - } else { - *dst = HstoreArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for HstoreArray", src) + } + if elementsLength == 0 { + *dst = HstoreArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to HstoreArray", value) + return errors.Errorf("cannot convert %v to HstoreArray", src) + } + + *dst = HstoreArray{ + Elements: make([]Hstore, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Hstore, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to HstoreArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *HstoreArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to HstoreArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in HstoreArray", err) + } + index++ + + return index, nil +} + func (dst HstoreArray) Get() interface{} { switch dst.Status { case Present: @@ -87,23 +133,26 @@ func (dst HstoreArray) Get() interface{} { func (src *HstoreArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]map[string]string: - *v = make([]map[string]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -111,6 +160,49 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *HstoreArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from HstoreArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = HstoreArray{Status: Null} diff --git a/hstore_array_test.go b/hstore_array_test.go index 32b91840..fac66b4a 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -131,7 +131,7 @@ func TestHstoreArrayTranscode(t *testing.T) { func TestHstoreArraySet(t *testing.T) { successfulTests := []struct { - src []map[string]string + src interface{} result pgtype.HstoreArray }{ { @@ -147,6 +147,118 @@ func TestHstoreArraySet(t *testing.T) { Status: pgtype.Present, }, }, + { + src: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + }, + { + src: [][][][]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + }, + { + src: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + }, + { + src: [2][1][1][3]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + }, } for i, tt := range successfulTests { @@ -163,12 +275,16 @@ func TestHstoreArraySet(t *testing.T) { } func TestHstoreArrayAssignTo(t *testing.T) { - var m []map[string]string + var hstoreSlice []map[string]string + var hstoreSliceDim2 [][]map[string]string + var hstoreSliceDim4 [][][][]map[string]string + var hstoreArrayDim2 [2][1]map[string]string + var hstoreArrayDim4 [2][1][1][3]map[string]string simpleTests := []struct { src pgtype.HstoreArray - dst *[]map[string]string - expected []map[string]string + dst interface{} + expected interface{} }{ { src: pgtype.HstoreArray{ @@ -181,9 +297,127 @@ func TestHstoreArrayAssignTo(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, - dst: &m, + dst: &hstoreSlice, expected: []map[string]string{{"foo": "bar"}}}, - {src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &m, expected: (([]map[string]string)(nil))}, + { + src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &hstoreSlice, expected: (([]map[string]string)(nil)), + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &hstoreSliceDim2, + expected: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + dst: &hstoreSliceDim4, + expected: [][][][]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &hstoreArrayDim2, + expected: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + dst: &hstoreArrayDim4, + expected: [2][1][1][3]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + }, } for i, tt := range simpleTests { @@ -192,8 +426,8 @@ func TestHstoreArrayAssignTo(t *testing.T) { t.Errorf("%d: %v", i, err) } - if !reflect.DeepEqual(*tt.dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } } } diff --git a/inet_array.go b/inet_array.go index ca4c1a02..91f5d6e8 100644 --- a/inet_array.go +++ b/inet_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "net" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,87 +31,94 @@ func (dst *InetArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = InetArray{Status: Null} + return nil + } - case []*net.IPNet: - if value == nil { - *dst = InetArray{Status: Null} - } else if len(value) == 0 { - *dst = InetArray{Status: Present} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []net.IP: - if value == nil { - *dst = InetArray{Status: Null} - } else if len(value) == 0 { - *dst = InetArray{Status: Present} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*net.IP: - if value == nil { - *dst = InetArray{Status: Null} - } else if len(value) == 0 { - *dst = InetArray{Status: Present} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Inet: - if value == nil { - *dst = InetArray{Status: Null} - } else if len(value) == 0 { - *dst = InetArray{Status: Present} - } else { - *dst = InetArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for InetArray", src) + } + if elementsLength == 0 { + *dst = InetArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to InetArray", value) + return errors.Errorf("cannot convert %v to InetArray", src) + } + + *dst = InetArray{ + Elements: make([]Inet, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Inet, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to InetArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *InetArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to InetArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in InetArray", err) + } + index++ + + return index, nil +} + func (dst InetArray) Get() interface{} { switch dst.Status { case Present: @@ -126,41 +133,26 @@ func (dst InetArray) Get() interface{} { func (src *InetArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]*net.IPNet: - *v = make([]*net.IPNet, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]net.IP: - *v = make([]net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.IP: - *v = make([]*net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -168,6 +160,49 @@ func (src *InetArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *InetArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from InetArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = InetArray{Status: Null} diff --git a/inet_array_test.go b/inet_array_test.go index 6737aac0..d78b91c0 100644 --- a/inet_array_test.go +++ b/inet_array_test.go @@ -80,6 +80,74 @@ func TestInetArraySet(t *testing.T) { source: (([]net.IP)(nil)), result: pgtype.InetArray{Status: pgtype.Null}, }, + { + source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -98,6 +166,10 @@ func TestInetArraySet(t *testing.T) { func TestInetArrayAssignTo(t *testing.T) { var ipnetSlice []*net.IPNet var ipSlice []net.IP + var ipSliceDim2 [][]net.IP + var ipnetSliceDim4 [][][][]*net.IPNet + var ipArrayDim2 [2][1]net.IP + var ipnetArrayDim4 [2][1][1][3]*net.IPNet simpleTests := []struct { src pgtype.InetArray @@ -150,6 +222,78 @@ func TestInetArrayAssignTo(t *testing.T) { dst: &ipSlice, expected: (([]net.IP)(nil)), }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipSliceDim2, + expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetSliceDim4, + expected: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipArrayDim2, + expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetArrayDim4, + expected: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, } for i, tt := range simpleTests { diff --git a/int2_array.go b/int2_array.go index ad2bd094..06febf01 100644 --- a/int2_array.go +++ b/int2_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,334 +31,94 @@ func (dst *Int2Array) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = Int2Array{Status: Null} + return nil + } - case []int16: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int16: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint16: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint16: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int32: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int32: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint32: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint32: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int64: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int64: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint64: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint64: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Int2: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - *dst = Int2Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Int2Array", src) + } + if elementsLength == 0 { + *dst = Int2Array{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int2Array", value) + return errors.Errorf("cannot convert %v to Int2Array", src) + } + + *dst = Int2Array{ + Elements: make([]Int2, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int2, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Int2Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *Int2Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to Int2Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in Int2Array", err) + } + index++ + + return index, nil +} + func (dst Int2Array) Get() interface{} { switch dst.Status { case Present: @@ -372,158 +133,26 @@ func (dst Int2Array) Get() interface{} { func (src *Int2Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -531,6 +160,49 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from Int2Array") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2Array{Status: Null} diff --git a/int2_array_test.go b/int2_array_test.go index 22f71745..dfe84c19 100644 --- a/int2_array_test.go +++ b/int2_array_test.go @@ -110,6 +110,54 @@ func TestInt2ArraySet(t *testing.T) { source: (([]int16)(nil)), result: pgtype.Int2Array{Status: pgtype.Null}, }, + { + source: [][]int16{{1}, {2}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]int16{{1}, {2}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -129,6 +177,10 @@ func TestInt2ArrayAssignTo(t *testing.T) { var int16Slice []int16 var uint16Slice []uint16 var namedInt16Slice _int16Slice + var int16SliceDim2 [][]int16 + var int16SliceDim4 [][][][]int16 + var int16ArrayDim2 [2][1]int16 + var int16ArrayDim4 [2][1][1][3]int16 simpleTests := []struct { src pgtype.Int2Array @@ -167,6 +219,58 @@ func TestInt2ArrayAssignTo(t *testing.T) { dst: &int16Slice, expected: (([]int16)(nil)), }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]int16{{1}, {2}}, + dst: &int16SliceDim2, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int16SliceDim4, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]int16{{1}, {2}}, + dst: &int16ArrayDim2, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int16ArrayDim4, + }, } for i, tt := range simpleTests { @@ -200,6 +304,27 @@ func TestInt2ArrayAssignTo(t *testing.T) { }, dst: &uint16Slice, }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int16ArrayDim2, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int16Slice, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &int16ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/int4_array.go b/int4_array.go index 15565f64..189bd238 100644 --- a/int4_array.go +++ b/int4_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,334 +31,94 @@ func (dst *Int4Array) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = Int4Array{Status: Null} + return nil + } - case []int16: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int16: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint16: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint16: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int32: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int32: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint32: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint32: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int64: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int64: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint64: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint64: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Int4: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - *dst = Int4Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Int4Array", src) + } + if elementsLength == 0 { + *dst = Int4Array{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int4Array", value) + return errors.Errorf("cannot convert %v to Int4Array", src) + } + + *dst = Int4Array{ + Elements: make([]Int4, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int4, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Int4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *Int4Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to Int4Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in Int4Array", err) + } + index++ + + return index, nil +} + func (dst Int4Array) Get() interface{} { switch dst.Status { case Present: @@ -372,158 +133,26 @@ func (dst Int4Array) Get() interface{} { func (src *Int4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -531,6 +160,49 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *Int4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from Int4Array") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4Array{Status: Null} diff --git a/int4_array_test.go b/int4_array_test.go index c839c1c9..35b791d3 100644 --- a/int4_array_test.go +++ b/int4_array_test.go @@ -116,6 +116,54 @@ func TestInt4ArraySet(t *testing.T) { source: (([]int32)(nil)), result: pgtype.Int4Array{Status: pgtype.Null}, }, + { + source: [][]int32{{1}, {2}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]int32{{1}, {2}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -143,6 +191,10 @@ func TestInt4ArrayAssignTo(t *testing.T) { var int32Slice []int32 var uint32Slice []uint32 var namedInt32Slice _int32Slice + var int32SliceDim2 [][]int32 + var int32SliceDim4 [][][][]int32 + var int32ArrayDim2 [2][1]int32 + var int32ArrayDim4 [2][1][1][3]int32 simpleTests := []struct { src pgtype.Int4Array @@ -181,6 +233,58 @@ func TestInt4ArrayAssignTo(t *testing.T) { dst: &int32Slice, expected: (([]int32)(nil)), }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]int32{{1}, {2}}, + dst: &int32SliceDim2, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int32SliceDim4, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]int32{{1}, {2}}, + dst: &int32ArrayDim2, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int32ArrayDim4, + }, } for i, tt := range simpleTests { @@ -214,6 +318,27 @@ func TestInt4ArrayAssignTo(t *testing.T) { }, dst: &uint32Slice, }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int32ArrayDim2, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int32Slice, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &int32ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/int8_array.go b/int8_array.go index e8e8823a..edb232cb 100644 --- a/int8_array.go +++ b/int8_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,334 +31,94 @@ func (dst *Int8Array) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = Int8Array{Status: Null} + return nil + } - case []int16: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int16: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint16: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint16: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int32: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int32: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint32: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint32: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int64: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int64: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint64: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint64: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Int8: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - *dst = Int8Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Int8Array", src) + } + if elementsLength == 0 { + *dst = Int8Array{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int8Array", value) + return errors.Errorf("cannot convert %v to Int8Array", src) + } + + *dst = Int8Array{ + Elements: make([]Int8, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int8, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Int8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *Int8Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to Int8Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in Int8Array", err) + } + index++ + + return index, nil +} + func (dst Int8Array) Get() interface{} { switch dst.Status { case Present: @@ -372,158 +133,26 @@ func (dst Int8Array) Get() interface{} { func (src *Int8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -531,6 +160,49 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *Int8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from Int8Array") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8Array{Status: Null} diff --git a/int8_array_test.go b/int8_array_test.go index e9e7acfb..d65b875a 100644 --- a/int8_array_test.go +++ b/int8_array_test.go @@ -117,6 +117,54 @@ func TestInt8ArraySet(t *testing.T) { source: (([]int64)(nil)), result: pgtype.Int8Array{Status: pgtype.Null}, }, + { + source: [][]int64{{1}, {2}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]int64{{1}, {2}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -136,6 +184,10 @@ func TestInt8ArrayAssignTo(t *testing.T) { var int64Slice []int64 var uint64Slice []uint64 var namedInt64Slice _int64Slice + var int64SliceDim2 [][]int64 + var int64SliceDim4 [][][][]int64 + var int64ArrayDim2 [2][1]int64 + var int64ArrayDim4 [2][1][1][3]int64 simpleTests := []struct { src pgtype.Int8Array @@ -174,6 +226,58 @@ func TestInt8ArrayAssignTo(t *testing.T) { dst: &int64Slice, expected: (([]int64)(nil)), }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]int64{{1}, {2}}, + dst: &int64SliceDim2, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int64SliceDim4, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]int64{{1}, {2}}, + dst: &int64ArrayDim2, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int64ArrayDim4, + }, } for i, tt := range simpleTests { @@ -207,6 +311,27 @@ func TestInt8ArrayAssignTo(t *testing.T) { }, dst: &uint64Slice, }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int64ArrayDim2, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int64Slice, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &int64ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/jsonb_array.go b/jsonb_array.go index daebfa7b..c5a40a1d 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,49 +31,94 @@ func (dst *JSONBArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = JSONBArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = JSONBArray{Status: Null} - } else if len(value) == 0 { - *dst = JSONBArray{Status: Present} - } else { - elements := make([]Text, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = JSONBArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Text: - if value == nil { - *dst = JSONBArray{Status: Null} - } else if len(value) == 0 { - *dst = JSONBArray{Status: Present} - } else { - *dst = JSONBArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for JSONBArray", src) + } + if elementsLength == 0 { + *dst = JSONBArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to JSONBArray", value) + return errors.Errorf("cannot convert %v to JSONBArray", src) + } + + *dst = JSONBArray{ + Elements: make([]Text, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Text, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to JSONBArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *JSONBArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to JSONBArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in JSONBArray", err) + } + index++ + + return index, nil +} + func (dst JSONBArray) Get() interface{} { switch dst.Status { case Present: @@ -87,23 +133,26 @@ func (dst JSONBArray) Get() interface{} { func (src *JSONBArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -111,6 +160,49 @@ func (src *JSONBArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *JSONBArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from JSONBArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *JSONBArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = JSONBArray{Status: Null} diff --git a/macaddr_array.go b/macaddr_array.go index 616d6f85..398db1fe 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "net" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,68 +31,94 @@ func (dst *MacaddrArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = MacaddrArray{Status: Null} + return nil + } - case []net.HardwareAddr: - if value == nil { - *dst = MacaddrArray{Status: Null} - } else if len(value) == 0 { - *dst = MacaddrArray{Status: Present} - } else { - elements := make([]Macaddr, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = MacaddrArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*net.HardwareAddr: - if value == nil { - *dst = MacaddrArray{Status: Null} - } else if len(value) == 0 { - *dst = MacaddrArray{Status: Present} - } else { - elements := make([]Macaddr, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = MacaddrArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Macaddr: - if value == nil { - *dst = MacaddrArray{Status: Null} - } else if len(value) == 0 { - *dst = MacaddrArray{Status: Present} - } else { - *dst = MacaddrArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for MacaddrArray", src) + } + if elementsLength == 0 { + *dst = MacaddrArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to MacaddrArray", value) + return errors.Errorf("cannot convert %v to MacaddrArray", src) + } + + *dst = MacaddrArray{ + Elements: make([]Macaddr, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Macaddr, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to MacaddrArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *MacaddrArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to MacaddrArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in MacaddrArray", err) + } + index++ + + return index, nil +} + func (dst MacaddrArray) Get() interface{} { switch dst.Status { case Present: @@ -107,32 +133,26 @@ func (dst MacaddrArray) Get() interface{} { func (src *MacaddrArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]net.HardwareAddr: - *v = make([]net.HardwareAddr, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.HardwareAddr: - *v = make([]*net.HardwareAddr, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -140,6 +160,49 @@ func (src *MacaddrArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *MacaddrArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from MacaddrArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *MacaddrArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = MacaddrArray{Status: Null} diff --git a/macaddr_array_test.go b/macaddr_array_test.go index d2b0a73b..647db8cf 100644 --- a/macaddr_array_test.go +++ b/macaddr_array_test.go @@ -44,6 +44,78 @@ func TestMacaddrArraySet(t *testing.T) { source: (([]net.HardwareAddr)(nil)), result: pgtype.MacaddrArray{Status: pgtype.Null}, }, + { + source: [][]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -61,6 +133,10 @@ func TestMacaddrArraySet(t *testing.T) { func TestMacaddrArrayAssignTo(t *testing.T) { var macaddrSlice []net.HardwareAddr + var macaddrSliceDim2 [][]net.HardwareAddr + var macaddrSliceDim4 [][][][]net.HardwareAddr + var macaddrArrayDim2 [2][1]net.HardwareAddr + var macaddrArrayDim4 [2][1][1][3]net.HardwareAddr simpleTests := []struct { src pgtype.MacaddrArray @@ -90,6 +166,82 @@ func TestMacaddrArrayAssignTo(t *testing.T) { dst: &macaddrSlice, expected: (([]net.HardwareAddr)(nil)), }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &macaddrSliceDim2, + expected: [][]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &macaddrSliceDim4, + expected: [][][][]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &macaddrArrayDim2, + expected: [2][1]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &macaddrArrayDim4, + expected: [2][1][1][3]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + }, } for i, tt := range simpleTests { diff --git a/numeric_array.go b/numeric_array.go index e086ca7a..dec81535 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,182 +31,94 @@ func (dst *NumericArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = NumericArray{Status: Null} + return nil + } - case []float32: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*float32: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []float64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*float64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Numeric: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - *dst = NumericArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for NumericArray", src) + } + if elementsLength == 0 { + *dst = NumericArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to NumericArray", value) + return errors.Errorf("cannot convert %v to NumericArray", src) + } + + *dst = NumericArray{ + Elements: make([]Numeric, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Numeric, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to NumericArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *NumericArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to NumericArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in NumericArray", err) + } + index++ + + return index, nil +} + func (dst NumericArray) Get() interface{} { switch dst.Status { case Present: @@ -220,86 +133,26 @@ func (dst NumericArray) Get() interface{} { func (src *NumericArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]float32: - *v = make([]float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float32: - *v = make([]*float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]float64: - *v = make([]float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float64: - *v = make([]*float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -307,6 +160,49 @@ func (src *NumericArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *NumericArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from NumericArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = NumericArray{Status: Null} diff --git a/numeric_array_test.go b/numeric_array_test.go index eafd31be..29300bf0 100644 --- a/numeric_array_test.go +++ b/numeric_array_test.go @@ -91,6 +91,54 @@ func TestNumericArraySet(t *testing.T) { source: (([]float32)(nil)), result: pgtype.NumericArray{Status: pgtype.Null}, }, + { + source: [][]float32{{1}, {2}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]float32{{1}, {2}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -109,6 +157,10 @@ func TestNumericArraySet(t *testing.T) { func TestNumericArrayAssignTo(t *testing.T) { var float32Slice []float32 var float64Slice []float64 + var float32SliceDim2 [][]float32 + var float32SliceDim4 [][][][]float32 + var float32ArrayDim2 [2][1]float32 + var float32ArrayDim4 [2][1][1][3]float32 simpleTests := []struct { src pgtype.NumericArray @@ -138,6 +190,58 @@ func TestNumericArrayAssignTo(t *testing.T) { dst: &float32Slice, expected: (([]float32)(nil)), }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32SliceDim2, + expected: [][]float32{{1}, {2}}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &float32SliceDim4, + expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32ArrayDim2, + expected: [2][1]float32{{1}, {2}}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &float32ArrayDim4, + expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + }, } for i, tt := range simpleTests { @@ -163,6 +267,27 @@ func TestNumericArrayAssignTo(t *testing.T) { }, dst: &float32Slice, }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32ArrayDim2, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32Slice, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/text_array.go b/text_array.go index d1583557..31ed04ac 100644 --- a/text_array.go +++ b/text_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *TextArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = TextArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = TextArray{Status: Null} - } else if len(value) == 0 { - *dst = TextArray{Status: Present} - } else { - elements := make([]Text, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TextArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = TextArray{Status: Null} - } else if len(value) == 0 { - *dst = TextArray{Status: Present} - } else { - elements := make([]Text, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TextArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Text: - if value == nil { - *dst = TextArray{Status: Null} - } else if len(value) == 0 { - *dst = TextArray{Status: Present} - } else { - *dst = TextArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TextArray", src) + } + if elementsLength == 0 { + *dst = TextArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TextArray", value) + return errors.Errorf("cannot convert %v to TextArray", src) + } + + *dst = TextArray{ + Elements: make([]Text, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Text, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TextArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *TextArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to TextArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in TextArray", err) + } + index++ + + return index, nil +} + func (dst TextArray) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst TextArray) Get() interface{} { func (src *TextArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *TextArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *TextArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from TextArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TextArray{Status: Null} diff --git a/text_array_test.go b/text_array_test.go index a29ce617..125d6034 100644 --- a/text_array_test.go +++ b/text_array_test.go @@ -68,6 +68,54 @@ func TestTextArraySet(t *testing.T) { source: (([]string)(nil)), result: pgtype.TextArray{Status: pgtype.Null}, }, + { + source: [][]string{{"foo"}, {"bar"}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"foo"}, {"bar"}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -87,6 +135,10 @@ func TestTextArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string simpleTests := []struct { src pgtype.TextArray @@ -116,6 +168,58 @@ func TestTextArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, } for i, tt := range simpleTests { @@ -141,6 +245,27 @@ func TestTextArrayAssignTo(t *testing.T) { }, dst: &stringSlice, }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, } for i, tt := range errorTests { diff --git a/timestamp_array.go b/timestamp_array.go index 3b2c3141..355b29c5 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "time" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,68 +31,94 @@ func (dst *TimestampArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = TimestampArray{Status: Null} + return nil + } - case []time.Time: - if value == nil { - *dst = TimestampArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestampArray{Status: Present} - } else { - elements := make([]Timestamp, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestampArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*time.Time: - if value == nil { - *dst = TimestampArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestampArray{Status: Present} - } else { - elements := make([]Timestamp, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestampArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Timestamp: - if value == nil { - *dst = TimestampArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestampArray{Status: Present} - } else { - *dst = TimestampArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TimestampArray", src) + } + if elementsLength == 0 { + *dst = TimestampArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TimestampArray", value) + return errors.Errorf("cannot convert %v to TimestampArray", src) + } + + *dst = TimestampArray{ + Elements: make([]Timestamp, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Timestamp, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TimestampArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *TimestampArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to TimestampArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in TimestampArray", err) + } + index++ + + return index, nil +} + func (dst TimestampArray) Get() interface{} { switch dst.Status { case Present: @@ -107,32 +133,26 @@ func (dst TimestampArray) Get() interface{} { func (src *TimestampArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*time.Time: - *v = make([]*time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -140,6 +160,49 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *TimestampArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from TimestampArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestampArray{Status: Null} diff --git a/timestamp_array_test.go b/timestamp_array_test.go index d7632fa3..c6f32d20 100644 --- a/timestamp_array_test.go +++ b/timestamp_array_test.go @@ -85,6 +85,42 @@ func TestTimestampArraySet(t *testing.T) { source: (([]time.Time)(nil)), result: pgtype.TimestampArray{Status: pgtype.Null}, }, + { + source: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -102,6 +138,10 @@ func TestTimestampArraySet(t *testing.T) { func TestTimestampArrayAssignTo(t *testing.T) { var timeSlice []time.Time + var timeSliceDim2 [][]time.Time + var timeSliceDim4 [][][][]time.Time + var timeArrayDim2 [2][1]time.Time + var timeArrayDim4 [2][1][1][3]time.Time simpleTests := []struct { src pgtype.TimestampArray @@ -122,6 +162,82 @@ func TestTimestampArrayAssignTo(t *testing.T) { dst: &timeSlice, expected: (([]time.Time)(nil)), }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeSliceDim2, + expected: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeSliceDim4, + expected: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + expected: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + expected: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, } for i, tt := range simpleTests { @@ -147,6 +263,33 @@ func TestTimestampArrayAssignTo(t *testing.T) { }, dst: &timeSlice, }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeSlice, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + }, } for i, tt := range errorTests { diff --git a/timestamptz_array.go b/timestamptz_array.go index 3328ec05..94a791b6 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "time" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,68 +31,94 @@ func (dst *TimestamptzArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = TimestamptzArray{Status: Null} + return nil + } - case []time.Time: - if value == nil { - *dst = TimestamptzArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestamptzArray{Status: Present} - } else { - elements := make([]Timestamptz, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestamptzArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*time.Time: - if value == nil { - *dst = TimestamptzArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestamptzArray{Status: Present} - } else { - elements := make([]Timestamptz, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestamptzArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Timestamptz: - if value == nil { - *dst = TimestamptzArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestamptzArray{Status: Present} - } else { - *dst = TimestamptzArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TimestamptzArray", src) + } + if elementsLength == 0 { + *dst = TimestamptzArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TimestamptzArray", value) + return errors.Errorf("cannot convert %v to TimestamptzArray", src) + } + + *dst = TimestamptzArray{ + Elements: make([]Timestamptz, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Timestamptz, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TimestamptzArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *TimestamptzArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to TimestamptzArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in TimestamptzArray", err) + } + index++ + + return index, nil +} + func (dst TimestamptzArray) Get() interface{} { switch dst.Status { case Present: @@ -107,32 +133,26 @@ func (dst TimestamptzArray) Get() interface{} { func (src *TimestamptzArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*time.Time: - *v = make([]*time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -140,6 +160,49 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *TimestamptzArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from TimestamptzArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestamptzArray{Status: Null} diff --git a/timestamptz_array_test.go b/timestamptz_array_test.go index 8a4cfd1d..f4e80413 100644 --- a/timestamptz_array_test.go +++ b/timestamptz_array_test.go @@ -85,6 +85,78 @@ func TestTimestamptzArraySet(t *testing.T) { source: (([]time.Time)(nil)), result: pgtype.TimestamptzArray{Status: pgtype.Null}, }, + { + source: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -102,6 +174,10 @@ func TestTimestamptzArraySet(t *testing.T) { func TestTimestamptzArrayAssignTo(t *testing.T) { var timeSlice []time.Time + var timeSliceDim2 [][]time.Time + var timeSliceDim4 [][][][]time.Time + var timeArrayDim2 [2][1]time.Time + var timeArrayDim4 [2][1][1][3]time.Time simpleTests := []struct { src pgtype.TimestamptzArray @@ -122,6 +198,82 @@ func TestTimestamptzArrayAssignTo(t *testing.T) { dst: &timeSlice, expected: (([]time.Time)(nil)), }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeSliceDim2, + expected: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeSliceDim4, + expected: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + expected: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + expected: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, } for i, tt := range simpleTests { @@ -147,6 +299,33 @@ func TestTimestamptzArrayAssignTo(t *testing.T) { }, dst: &timeSlice, }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeSlice, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + }, } for i, tt := range errorTests { diff --git a/tstzrange_array.go b/tstzrange_array.go index c19a9bfa..f5043c65 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,30 +31,94 @@ func (dst *TstzrangeArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = TstzrangeArray{Status: Null} + return nil + } - case []Tstzrange: - if value == nil { - *dst = TstzrangeArray{Status: Null} - } else if len(value) == 0 { - *dst = TstzrangeArray{Status: Present} - } else { - *dst = TstzrangeArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TstzrangeArray", src) + } + if elementsLength == 0 { + *dst = TstzrangeArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TstzrangeArray", value) + return errors.Errorf("cannot convert %v to TstzrangeArray", src) + } + + *dst = TstzrangeArray{ + Elements: make([]Tstzrange, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Tstzrange, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TstzrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *TstzrangeArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to TstzrangeArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in TstzrangeArray", err) + } + index++ + + return index, nil +} + func (dst TstzrangeArray) Get() interface{} { switch dst.Status { case Present: @@ -68,23 +133,26 @@ func (dst TstzrangeArray) Get() interface{} { func (src *TstzrangeArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]Tstzrange: - *v = make([]Tstzrange, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -92,6 +160,49 @@ func (src *TstzrangeArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *TstzrangeArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from TstzrangeArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *TstzrangeArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TstzrangeArray{Status: Null} diff --git a/typed_array.go.erb b/typed_array.go.erb index a3deea5b..fb964ec8 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -30,51 +30,94 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { } } - switch value := src.(type) { - <% go_array_types.split(",").each do |t| %> - <% if t != "[]#{pgtype_element_type}" %> - case <%= t %>: - if value == nil { - *dst = <%= pgtype_array_type %>{Status: Null} - } else if len(value) == 0 { - *dst = <%= pgtype_array_type %>{Status: Present} - } else { - elements := make([]<%= pgtype_element_type %>, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = <%= pgtype_array_type %>{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - <% end %> - <% end %> - case []<%= pgtype_element_type %>: - if value == nil { - *dst = <%= pgtype_array_type %>{Status: Null} - } else if len(value) == 0 { - *dst = <%= pgtype_array_type %>{Status: Present} - } else { - *dst = <%= pgtype_array_type %>{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status : Present, - } - } - default: + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = <%= pgtype_array_type %>{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for <%= pgtype_array_type %>", src) + } + if elementsLength == 0 { + *dst = <%= pgtype_array_type %>{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>", value) + return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>", src) + } + + *dst = <%= pgtype_array_type %> { + Elements: make([]<%= pgtype_element_type %>, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]<%= pgtype_element_type %>, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *<%= pgtype_array_type %>) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to <%= pgtype_array_type %>") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in <%= pgtype_array_type %>", err) + } + index++ + + return index, nil +} + func (dst <%= pgtype_array_type %>) Get() interface{} { switch dst.Status { case Present: @@ -89,23 +132,26 @@ func (dst <%= pgtype_array_type %>) Get() interface{} { func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - <% go_array_types.split(",").each do |t| %> - case *<%= t %>: - *v = make(<%= t %>, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - <% end %> - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -113,6 +159,49 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from <%= pgtype_array_type %>") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 607d3bc3..8c594944 100755 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -1,27 +1,27 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time,[]*time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time,[]*time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go -erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time,[]*time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32,[]*float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64,[]*float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go -erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr go_array_types=[]net.HardwareAddr,[]*net.HardwareAddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go -erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string,[]*string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go -erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string,[]*string element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go -erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string,[]*string element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go -erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string,[]*string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go -erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go -erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]*float32,[]float64,[]*float64,[]int64,[]*int64,[]uint64,[]*uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go -erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string,[]*string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go -erb pgtype_array_type=JSONBArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go +erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go +erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go +erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go +erb pgtype_array_type=NumericArray pgtype_element_type=Numeric element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go +erb pgtype_array_type=UUIDArray pgtype_element_type=UUID element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go +erb pgtype_array_type=JSONBArray pgtype_element_type=Text element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go # While the binary format is theoretically possible it is only practical to use the text format. -erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string text_null=NULL binary_format=false typed_array.go.erb > enum_array.go +erb pgtype_array_type=EnumArray pgtype_element_type=GenericText text_null=NULL binary_format=false typed_array.go.erb > enum_array.go goimports -w *_array.go diff --git a/uuid_array.go b/uuid_array.go index 06d2d576..e2c86cf8 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,106 +31,94 @@ func (dst *UUIDArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = UUIDArray{Status: Null} + return nil + } - case [][16]byte: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case [][]byte: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []string: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []UUID: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - *dst = UUIDArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for UUIDArray", src) + } + if elementsLength == 0 { + *dst = UUIDArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to UUIDArray", value) + return errors.Errorf("cannot convert %v to UUIDArray", src) + } + + *dst = UUIDArray{ + Elements: make([]UUID, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]UUID, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to UUIDArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *UUIDArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to UUIDArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in UUIDArray", err) + } + index++ + + return index, nil +} + func (dst UUIDArray) Get() interface{} { switch dst.Status { case Present: @@ -144,50 +133,26 @@ func (dst UUIDArray) Get() interface{} { func (src *UUIDArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[][16]byte: - *v = make([][16]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[][]byte: - *v = make([][]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -195,6 +160,49 @@ func (src *UUIDArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *UUIDArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from UUIDArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *UUIDArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = UUIDArray{Status: Null} diff --git a/uuid_array_test.go b/uuid_array_test.go index d5446920..cdb212bb 100644 --- a/uuid_array_test.go +++ b/uuid_array_test.go @@ -123,6 +123,78 @@ func TestUUIDArraySet(t *testing.T) { source: ([]string)(nil), result: pgtype.UUIDArray{Status: pgtype.Null}, }, + { + source: [][][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -142,6 +214,10 @@ func TestUUIDArrayAssignTo(t *testing.T) { var byteArraySlice [][16]byte var byteSliceSlice [][]byte var stringSlice []string + var byteArraySliceDim2 [][][16]byte + var stringSliceDim4 [][][][]string + var byteArrayDim2 [2][1][16]byte + var stringArrayDim4 [2][1][1][3]string simpleTests := []struct { src pgtype.UUIDArray @@ -190,6 +266,82 @@ func TestUUIDArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: ([]string)(nil), }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteArraySliceDim2, + expected: [][][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteArrayDim2, + expected: [2][1][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + }, } for i, tt := range simpleTests { diff --git a/varchar_array.go b/varchar_array.go index 32ca5941..ec378ed7 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *VarcharArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = VarcharArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = VarcharArray{Status: Null} - } else if len(value) == 0 { - *dst = VarcharArray{Status: Present} - } else { - elements := make([]Varchar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = VarcharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = VarcharArray{Status: Null} - } else if len(value) == 0 { - *dst = VarcharArray{Status: Present} - } else { - elements := make([]Varchar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = VarcharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Varchar: - if value == nil { - *dst = VarcharArray{Status: Null} - } else if len(value) == 0 { - *dst = VarcharArray{Status: Present} - } else { - *dst = VarcharArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for VarcharArray", src) + } + if elementsLength == 0 { + *dst = VarcharArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to VarcharArray", value) + return errors.Errorf("cannot convert %v to VarcharArray", src) + } + + *dst = VarcharArray{ + Elements: make([]Varchar, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Varchar, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to VarcharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *VarcharArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to VarcharArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in VarcharArray", err) + } + index++ + + return index, nil +} + func (dst VarcharArray) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst VarcharArray) Get() interface{} { func (src *VarcharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *VarcharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from VarcharArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = VarcharArray{Status: Null} diff --git a/varchar_array_test.go b/varchar_array_test.go index 9ad80862..3b0e65ed 100644 --- a/varchar_array_test.go +++ b/varchar_array_test.go @@ -68,6 +68,54 @@ func TestVarcharArraySet(t *testing.T) { source: (([]string)(nil)), result: pgtype.VarcharArray{Status: pgtype.Null}, }, + { + source: [][]string{{"foo"}, {"bar"}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"foo"}, {"bar"}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -87,6 +135,10 @@ func TestVarcharArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string simpleTests := []struct { src pgtype.VarcharArray @@ -116,6 +168,58 @@ func TestVarcharArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, } for i, tt := range simpleTests { @@ -141,6 +245,27 @@ func TestVarcharArrayAssignTo(t *testing.T) { }, dst: &stringSlice, }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, } for i, tt := range errorTests {