diff --git a/aclitem_array.go b/aclitem_array.go index 2df0ccd4..09a64fb6 100644 --- a/aclitem_array.go +++ b/aclitem_array.go @@ -4,6 +4,7 @@ package pgtype import ( "database/sql/driver" + "reflect" errors "golang.org/x/xerrors" ) @@ -28,68 +29,94 @@ func (dst *ACLItemArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = ACLItemArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = ACLItemArray{Status: Null} - } else if len(value) == 0 { - *dst = ACLItemArray{Status: Present} - } else { - elements := make([]ACLItem, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = ACLItemArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = ACLItemArray{Status: Null} - } else if len(value) == 0 { - *dst = ACLItemArray{Status: Present} - } else { - elements := make([]ACLItem, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = ACLItemArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []ACLItem: - if value == nil { - *dst = ACLItemArray{Status: Null} - } else if len(value) == 0 { - *dst = ACLItemArray{Status: Present} - } else { - *dst = ACLItemArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for ACLItemArray", src) + } + if elementsLength == 0 { + *dst = ACLItemArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to ACLItemArray", value) + return errors.Errorf("cannot convert %v to ACLItemArray", src) + } + + *dst = ACLItemArray{ + Elements: make([]ACLItem, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]ACLItem, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to ACLItemArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *ACLItemArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to ACLItemArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in ACLItemArray", err) + } + index++ + + return index, nil +} + func (dst ACLItemArray) Get() interface{} { switch dst.Status { case Present: @@ -104,32 +131,26 @@ func (dst ACLItemArray) Get() interface{} { func (src *ACLItemArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -137,6 +158,49 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *ACLItemArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from ACLItemArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = ACLItemArray{Status: Null} diff --git a/aclitem_array_test.go b/aclitem_array_test.go index fb1e93fc..73e9ce71 100644 --- a/aclitem_array_test.go +++ b/aclitem_array_test.go @@ -69,6 +69,74 @@ func TestACLItemArraySet(t *testing.T) { source: (([]string)(nil)), result: pgtype.ACLItemArray{Status: pgtype.Null}, }, + { + source: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -88,6 +156,10 @@ func TestACLItemArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string simpleTests := []struct { src pgtype.ACLItemArray @@ -117,6 +189,78 @@ func TestACLItemArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + }, } for i, tt := range simpleTests { @@ -142,6 +286,33 @@ func TestACLItemArrayAssignTo(t *testing.T) { }, dst: &stringSlice, }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Status: pgtype.Present}, + {String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, } for i, tt := range errorTests { diff --git a/array.go b/array.go index bd3a993b..b779cd9d 100644 --- a/array.go +++ b/array.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "io" + "reflect" "strconv" "strings" "unicode" @@ -350,3 +351,24 @@ func QuoteArrayElementIfNeeded(src string) string { } return src } + +func findDimensionsFromValue(value reflect.Value, dimensions []ArrayDimension, elementsLength int) ([]ArrayDimension, int, bool) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + length := value.Len() + if 0 == elementsLength { + elementsLength = length + } else { + elementsLength *= length + } + dimensions = append(dimensions, ArrayDimension{Length: int32(length), LowerBound: 1}) + for i := 0; i < length; i++ { + if d, l, ok := findDimensionsFromValue(value.Index(i), dimensions, elementsLength); ok { + return d, l, true + } + } + } + return dimensions, elementsLength, true +} diff --git a/bool_array.go b/bool_array.go index a8c75a25..6569d5ca 100644 --- a/bool_array.go +++ b/bool_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *BoolArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = BoolArray{Status: Null} + return nil + } - case []bool: - if value == nil { - *dst = BoolArray{Status: Null} - } else if len(value) == 0 { - *dst = BoolArray{Status: Present} - } else { - elements := make([]Bool, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BoolArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*bool: - if value == nil { - *dst = BoolArray{Status: Null} - } else if len(value) == 0 { - *dst = BoolArray{Status: Present} - } else { - elements := make([]Bool, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BoolArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Bool: - if value == nil { - *dst = BoolArray{Status: Null} - } else if len(value) == 0 { - *dst = BoolArray{Status: Present} - } else { - *dst = BoolArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for BoolArray", src) + } + if elementsLength == 0 { + *dst = BoolArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to BoolArray", value) + return errors.Errorf("cannot convert %v to BoolArray", src) + } + + *dst = BoolArray{ + Elements: make([]Bool, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Bool, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to BoolArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *BoolArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to BoolArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in BoolArray", err) + } + index++ + + return index, nil +} + func (dst BoolArray) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst BoolArray) Get() interface{} { func (src *BoolArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]bool: - *v = make([]bool, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*bool: - *v = make([]*bool, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from BoolArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = BoolArray{Status: Null} diff --git a/bool_array_test.go b/bool_array_test.go index bef94622..7f31e252 100644 --- a/bool_array_test.go +++ b/bool_array_test.go @@ -68,6 +68,54 @@ func TestBoolArraySet(t *testing.T) { source: (([]bool)(nil)), result: pgtype.BoolArray{Status: pgtype.Null}, }, + { + source: [][]bool{{true}, {false}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]bool{{true}, {false}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -87,6 +135,10 @@ func TestBoolArrayAssignTo(t *testing.T) { var boolSlice []bool type _boolSlice []bool var namedBoolSlice _boolSlice + var boolSliceDim2 [][]bool + var boolSliceDim4 [][][][]bool + var boolArrayDim2 [2][1]bool + var boolArrayDim4 [2][1][1][3]bool simpleTests := []struct { src pgtype.BoolArray @@ -116,6 +168,58 @@ func TestBoolArrayAssignTo(t *testing.T) { dst: &boolSlice, expected: (([]bool)(nil)), }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]bool{{true}, {false}}, + dst: &boolSliceDim2, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + dst: &boolSliceDim4, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]bool{{true}, {false}}, + dst: &boolArrayDim2, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}, + {Bool: true, Status: pgtype.Present}, + {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + dst: &boolArrayDim4, + }, } for i, tt := range simpleTests { @@ -141,6 +245,27 @@ func TestBoolArrayAssignTo(t *testing.T) { }, dst: &boolSlice, }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &boolArrayDim2, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &boolSlice, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}, {Bool: false, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &boolArrayDim4, + }, } for i, tt := range errorTests { diff --git a/bpchar_array.go b/bpchar_array.go index ed6fe703..8aef8330 100644 --- a/bpchar_array.go +++ b/bpchar_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *BPCharArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = BPCharArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = BPCharArray{Status: Null} - } else if len(value) == 0 { - *dst = BPCharArray{Status: Present} - } else { - elements := make([]BPChar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BPCharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = BPCharArray{Status: Null} - } else if len(value) == 0 { - *dst = BPCharArray{Status: Present} - } else { - elements := make([]BPChar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BPCharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []BPChar: - if value == nil { - *dst = BPCharArray{Status: Null} - } else if len(value) == 0 { - *dst = BPCharArray{Status: Present} - } else { - *dst = BPCharArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for BPCharArray", src) + } + if elementsLength == 0 { + *dst = BPCharArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to BPCharArray", value) + return errors.Errorf("cannot convert %v to BPCharArray", src) + } + + *dst = BPCharArray{ + Elements: make([]BPChar, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]BPChar, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to BPCharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *BPCharArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to BPCharArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in BPCharArray", err) + } + index++ + + return index, nil +} + func (dst BPCharArray) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst BPCharArray) Get() interface{} { func (src *BPCharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *BPCharArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *BPCharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from BPCharArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *BPCharArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = BPCharArray{Status: Null} diff --git a/bytea_array.go b/bytea_array.go index 87d77f9e..3addb99a 100644 --- a/bytea_array.go +++ b/bytea_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,49 +31,94 @@ func (dst *ByteaArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = ByteaArray{Status: Null} + return nil + } - case [][]byte: - if value == nil { - *dst = ByteaArray{Status: Null} - } else if len(value) == 0 { - *dst = ByteaArray{Status: Present} - } else { - elements := make([]Bytea, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = ByteaArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Bytea: - if value == nil { - *dst = ByteaArray{Status: Null} - } else if len(value) == 0 { - *dst = ByteaArray{Status: Present} - } else { - *dst = ByteaArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for ByteaArray", src) + } + if elementsLength == 0 { + *dst = ByteaArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to ByteaArray", value) + return errors.Errorf("cannot convert %v to ByteaArray", src) + } + + *dst = ByteaArray{ + Elements: make([]Bytea, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Bytea, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to ByteaArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *ByteaArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to ByteaArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in ByteaArray", err) + } + index++ + + return index, nil +} + func (dst ByteaArray) Get() interface{} { switch dst.Status { case Present: @@ -87,23 +133,26 @@ func (dst ByteaArray) Get() interface{} { func (src *ByteaArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[][]byte: - *v = make([][]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -111,6 +160,49 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *ByteaArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from ByteaArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = ByteaArray{Status: Null} diff --git a/bytea_array_test.go b/bytea_array_test.go index a4eb2d91..f40005a2 100644 --- a/bytea_array_test.go +++ b/bytea_array_test.go @@ -68,6 +68,54 @@ func TestByteaArraySet(t *testing.T) { source: (([][]byte)(nil)), result: pgtype.ByteaArray{Status: pgtype.Null}, }, + { + source: [][][]byte{{{1}}, {{2}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1][]byte{{{1}}, {{2}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -85,6 +133,10 @@ func TestByteaArraySet(t *testing.T) { func TestByteaArrayAssignTo(t *testing.T) { var byteByteSlice [][]byte + var byteByteSliceDim2 [][][]byte + var byteByteSliceDim4 [][][][][]byte + var byteByteArraySliceDim2 [2][1][]byte + var byteByteArraySliceDim4 [2][1][1][3][]byte simpleTests := []struct { src pgtype.ByteaArray @@ -105,6 +157,58 @@ func TestByteaArrayAssignTo(t *testing.T) { dst: &byteByteSlice, expected: (([][]byte)(nil)), }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteByteSliceDim2, + expected: [][][]byte{{{1}}, {{2}}}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &byteByteSliceDim4, + expected: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Status: pgtype.Present}, {Bytes: []byte{2}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteByteArraySliceDim2, + expected: [2][1][]byte{{{1}}, {{2}}}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + {Bytes: []byte{4, 5, 6}, Status: pgtype.Present}, + {Bytes: []byte{7, 8, 9}, Status: pgtype.Present}, + {Bytes: []byte{10, 11, 12}, Status: pgtype.Present}, + {Bytes: []byte{13, 14, 15}, Status: pgtype.Present}, + {Bytes: []byte{16, 17, 18}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &byteByteArraySliceDim4, + expected: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + }, } for i, tt := range simpleTests { diff --git a/cidr_array.go b/cidr_array.go index a2e025cc..1ef2f428 100644 --- a/cidr_array.go +++ b/cidr_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "net" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,87 +31,94 @@ func (dst *CIDRArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = CIDRArray{Status: Null} + return nil + } - case []*net.IPNet: - if value == nil { - *dst = CIDRArray{Status: Null} - } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []net.IP: - if value == nil { - *dst = CIDRArray{Status: Null} - } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*net.IP: - if value == nil { - *dst = CIDRArray{Status: Null} - } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []CIDR: - if value == nil { - *dst = CIDRArray{Status: Null} - } else if len(value) == 0 { - *dst = CIDRArray{Status: Present} - } else { - *dst = CIDRArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for CIDRArray", src) + } + if elementsLength == 0 { + *dst = CIDRArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to CIDRArray", value) + return errors.Errorf("cannot convert %v to CIDRArray", src) + } + + *dst = CIDRArray{ + Elements: make([]CIDR, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]CIDR, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to CIDRArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *CIDRArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to CIDRArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in CIDRArray", err) + } + index++ + + return index, nil +} + func (dst CIDRArray) Get() interface{} { switch dst.Status { case Present: @@ -126,41 +133,26 @@ func (dst CIDRArray) Get() interface{} { func (src *CIDRArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]*net.IPNet: - *v = make([]*net.IPNet, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]net.IP: - *v = make([]net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.IP: - *v = make([]*net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -168,6 +160,49 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from CIDRArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = CIDRArray{Status: Null} diff --git a/cidr_array_test.go b/cidr_array_test.go index 421aec4e..b1769c38 100644 --- a/cidr_array_test.go +++ b/cidr_array_test.go @@ -80,6 +80,74 @@ func TestCIDRArraySet(t *testing.T) { source: (([]net.IP)(nil)), result: pgtype.CIDRArray{Status: pgtype.Null}, }, + { + source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -98,6 +166,10 @@ func TestCIDRArraySet(t *testing.T) { func TestCIDRArrayAssignTo(t *testing.T) { var ipnetSlice []*net.IPNet var ipSlice []net.IP + var ipSliceDim2 [][]net.IP + var ipnetSliceDim4 [][][][]*net.IPNet + var ipArrayDim2 [2][1]net.IP + var ipnetArrayDim4 [2][1][1][3]*net.IPNet simpleTests := []struct { src pgtype.CIDRArray @@ -150,6 +222,78 @@ func TestCIDRArrayAssignTo(t *testing.T) { dst: &ipSlice, expected: (([]net.IP)(nil)), }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipSliceDim2, + expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetSliceDim4, + expected: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipArrayDim2, + expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetArrayDim4, + expected: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, } for i, tt := range simpleTests { diff --git a/date_array.go b/date_array.go index fe185f67..4ccdafe0 100644 --- a/date_array.go +++ b/date_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "time" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,68 +31,94 @@ func (dst *DateArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = DateArray{Status: Null} + return nil + } - case []time.Time: - if value == nil { - *dst = DateArray{Status: Null} - } else if len(value) == 0 { - *dst = DateArray{Status: Present} - } else { - elements := make([]Date, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = DateArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*time.Time: - if value == nil { - *dst = DateArray{Status: Null} - } else if len(value) == 0 { - *dst = DateArray{Status: Present} - } else { - elements := make([]Date, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = DateArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Date: - if value == nil { - *dst = DateArray{Status: Null} - } else if len(value) == 0 { - *dst = DateArray{Status: Present} - } else { - *dst = DateArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for DateArray", src) + } + if elementsLength == 0 { + *dst = DateArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to DateArray", value) + return errors.Errorf("cannot convert %v to DateArray", src) + } + + *dst = DateArray{ + Elements: make([]Date, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Date, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to DateArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *DateArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to DateArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in DateArray", err) + } + index++ + + return index, nil +} + func (dst DateArray) Get() interface{} { switch dst.Status { case Present: @@ -107,32 +133,26 @@ func (dst DateArray) Get() interface{} { func (src *DateArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*time.Time: - *v = make([]*time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -140,6 +160,49 @@ func (src *DateArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *DateArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from DateArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = DateArray{Status: Null} diff --git a/date_array_test.go b/date_array_test.go index 9f4a96a9..089c7dd4 100644 --- a/date_array_test.go +++ b/date_array_test.go @@ -69,6 +69,78 @@ func TestDateArraySet(t *testing.T) { source: (([]time.Time)(nil)), result: pgtype.DateArray{Status: pgtype.Null}, }, + { + source: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -86,6 +158,10 @@ func TestDateArraySet(t *testing.T) { func TestDateArrayAssignTo(t *testing.T) { var timeSlice []time.Time + var timeSliceDim2 [][]time.Time + var timeSliceDim4 [][][][]time.Time + var timeArrayDim2 [2][1]time.Time + var timeArrayDim4 [2][1][1][3]time.Time simpleTests := []struct { src pgtype.DateArray @@ -106,6 +182,82 @@ func TestDateArrayAssignTo(t *testing.T) { dst: &timeSlice, expected: (([]time.Time)(nil)), }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeSliceDim2, + expected: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeSliceDim4, + expected: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + expected: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + expected: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, } for i, tt := range simpleTests { @@ -131,6 +283,33 @@ func TestDateArrayAssignTo(t *testing.T) { }, dst: &timeSlice, }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeSlice, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + }, } for i, tt := range errorTests { diff --git a/enum_array.go b/enum_array.go index 9312264c..2c83db24 100644 --- a/enum_array.go +++ b/enum_array.go @@ -4,6 +4,7 @@ package pgtype import ( "database/sql/driver" + "reflect" errors "golang.org/x/xerrors" ) @@ -28,68 +29,94 @@ func (dst *EnumArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = EnumArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = EnumArray{Status: Null} - } else if len(value) == 0 { - *dst = EnumArray{Status: Present} - } else { - elements := make([]GenericText, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = EnumArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = EnumArray{Status: Null} - } else if len(value) == 0 { - *dst = EnumArray{Status: Present} - } else { - elements := make([]GenericText, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = EnumArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []GenericText: - if value == nil { - *dst = EnumArray{Status: Null} - } else if len(value) == 0 { - *dst = EnumArray{Status: Present} - } else { - *dst = EnumArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for EnumArray", src) + } + if elementsLength == 0 { + *dst = EnumArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to EnumArray", value) + return errors.Errorf("cannot convert %v to EnumArray", src) + } + + *dst = EnumArray{ + Elements: make([]GenericText, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]GenericText, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to EnumArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *EnumArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to EnumArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in EnumArray", err) + } + index++ + + return index, nil +} + func (dst EnumArray) Get() interface{} { switch dst.Status { case Present: @@ -104,32 +131,26 @@ func (dst EnumArray) Get() interface{} { func (src *EnumArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -137,6 +158,49 @@ func (src *EnumArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *EnumArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from EnumArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = EnumArray{Status: Null} diff --git a/enum_array_test.go b/enum_array_test.go index 406c6b47..91a81ab6 100644 --- a/enum_array_test.go +++ b/enum_array_test.go @@ -67,6 +67,54 @@ func TestEnumArrayArraySet(t *testing.T) { source: (([]string)(nil)), result: pgtype.EnumArray{Status: pgtype.Null}, }, + { + source: [][]string{{"foo"}, {"bar"}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"foo"}, {"bar"}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -86,6 +134,10 @@ func TestEnumArrayArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string simpleTests := []struct { src pgtype.EnumArray @@ -115,6 +167,58 @@ func TestEnumArrayArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, } for i, tt := range simpleTests { @@ -140,6 +244,27 @@ func TestEnumArrayArrayAssignTo(t *testing.T) { }, dst: &stringSlice, }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, } for i, tt := range errorTests { diff --git a/float4_array.go b/float4_array.go index 0e95c446..78d1a860 100644 --- a/float4_array.go +++ b/float4_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *Float4Array) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = Float4Array{Status: Null} + return nil + } - case []float32: - if value == nil { - *dst = Float4Array{Status: Null} - } else if len(value) == 0 { - *dst = Float4Array{Status: Present} - } else { - elements := make([]Float4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*float32: - if value == nil { - *dst = Float4Array{Status: Null} - } else if len(value) == 0 { - *dst = Float4Array{Status: Present} - } else { - elements := make([]Float4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Float4: - if value == nil { - *dst = Float4Array{Status: Null} - } else if len(value) == 0 { - *dst = Float4Array{Status: Present} - } else { - *dst = Float4Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Float4Array", src) + } + if elementsLength == 0 { + *dst = Float4Array{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Float4Array", value) + return errors.Errorf("cannot convert %v to Float4Array", src) + } + + *dst = Float4Array{ + Elements: make([]Float4, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Float4, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Float4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *Float4Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to Float4Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in Float4Array", err) + } + index++ + + return index, nil +} + func (dst Float4Array) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst Float4Array) Get() interface{} { func (src *Float4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]float32: - *v = make([]float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float32: - *v = make([]*float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *Float4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from Float4Array") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4Array{Status: Null} diff --git a/float4_array_test.go b/float4_array_test.go index 658b3381..23a94ee8 100644 --- a/float4_array_test.go +++ b/float4_array_test.go @@ -68,6 +68,54 @@ func TestFloat4ArraySet(t *testing.T) { source: (([]float32)(nil)), result: pgtype.Float4Array{Status: pgtype.Null}, }, + { + source: [][]float32{{1}, {2}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]float32{{1}, {2}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -86,6 +134,10 @@ func TestFloat4ArraySet(t *testing.T) { func TestFloat4ArrayAssignTo(t *testing.T) { var float32Slice []float32 var namedFloat32Slice _float32Slice + var float32SliceDim2 [][]float32 + var float32SliceDim4 [][][][]float32 + var float32ArrayDim2 [2][1]float32 + var float32ArrayDim4 [2][1][1][3]float32 simpleTests := []struct { src pgtype.Float4Array @@ -115,6 +167,58 @@ func TestFloat4ArrayAssignTo(t *testing.T) { dst: &float32Slice, expected: (([]float32)(nil)), }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]float32{{1}, {2}}, + dst: &float32SliceDim2, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float32SliceDim4, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]float32{{1}, {2}}, + dst: &float32ArrayDim2, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float32ArrayDim4, + }, } for i, tt := range simpleTests { @@ -140,6 +244,27 @@ func TestFloat4ArrayAssignTo(t *testing.T) { }, dst: &float32Slice, }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32ArrayDim2, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32Slice, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/float8_array.go b/float8_array.go index 240e88d6..19223c52 100644 --- a/float8_array.go +++ b/float8_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *Float8Array) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = Float8Array{Status: Null} + return nil + } - case []float64: - if value == nil { - *dst = Float8Array{Status: Null} - } else if len(value) == 0 { - *dst = Float8Array{Status: Present} - } else { - elements := make([]Float8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*float64: - if value == nil { - *dst = Float8Array{Status: Null} - } else if len(value) == 0 { - *dst = Float8Array{Status: Present} - } else { - elements := make([]Float8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Float8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Float8: - if value == nil { - *dst = Float8Array{Status: Null} - } else if len(value) == 0 { - *dst = Float8Array{Status: Present} - } else { - *dst = Float8Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Float8Array", src) + } + if elementsLength == 0 { + *dst = Float8Array{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Float8Array", value) + return errors.Errorf("cannot convert %v to Float8Array", src) + } + + *dst = Float8Array{ + Elements: make([]Float8, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Float8, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Float8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *Float8Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to Float8Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in Float8Array", err) + } + index++ + + return index, nil +} + func (dst Float8Array) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst Float8Array) Get() interface{} { func (src *Float8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]float64: - *v = make([]float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float64: - *v = make([]*float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *Float8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from Float8Array") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8Array{Status: Null} diff --git a/float8_array_test.go b/float8_array_test.go index 2e29a19f..052ab3f3 100644 --- a/float8_array_test.go +++ b/float8_array_test.go @@ -68,6 +68,30 @@ func TestFloat8ArraySet(t *testing.T) { source: (([]float64)(nil)), result: pgtype.Float8Array{Status: pgtype.Null}, }, + { + source: [][]float64{{1}, {2}}, + result: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -86,6 +110,10 @@ func TestFloat8ArraySet(t *testing.T) { func TestFloat8ArrayAssignTo(t *testing.T) { var float64Slice []float64 var namedFloat64Slice _float64Slice + var float64SliceDim2 [][]float64 + var float64SliceDim4 [][][][]float64 + var float64ArrayDim2 [2][1]float64 + var float64ArrayDim4 [2][1][1][3]float64 simpleTests := []struct { src pgtype.Float8Array @@ -115,6 +143,58 @@ func TestFloat8ArrayAssignTo(t *testing.T) { dst: &float64Slice, expected: (([]float64)(nil)), }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]float64{{1}, {2}}, + dst: &float64SliceDim2, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float64SliceDim4, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]float64{{1}, {2}}, + dst: &float64ArrayDim2, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Status: pgtype.Present}, + {Float: 2, Status: pgtype.Present}, + {Float: 3, Status: pgtype.Present}, + {Float: 4, Status: pgtype.Present}, + {Float: 5, Status: pgtype.Present}, + {Float: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float64ArrayDim4, + }, } for i, tt := range simpleTests { @@ -140,6 +220,27 @@ func TestFloat8ArrayAssignTo(t *testing.T) { }, dst: &float64Slice, }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float64ArrayDim2, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float64Slice, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}, {Float: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float64ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/hstore_array.go b/hstore_array.go index b258cbdd..8764aae7 100644 --- a/hstore_array.go +++ b/hstore_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,49 +31,94 @@ func (dst *HstoreArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = HstoreArray{Status: Null} + return nil + } - case []map[string]string: - if value == nil { - *dst = HstoreArray{Status: Null} - } else if len(value) == 0 { - *dst = HstoreArray{Status: Present} - } else { - elements := make([]Hstore, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = HstoreArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Hstore: - if value == nil { - *dst = HstoreArray{Status: Null} - } else if len(value) == 0 { - *dst = HstoreArray{Status: Present} - } else { - *dst = HstoreArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for HstoreArray", src) + } + if elementsLength == 0 { + *dst = HstoreArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to HstoreArray", value) + return errors.Errorf("cannot convert %v to HstoreArray", src) + } + + *dst = HstoreArray{ + Elements: make([]Hstore, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Hstore, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to HstoreArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *HstoreArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to HstoreArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in HstoreArray", err) + } + index++ + + return index, nil +} + func (dst HstoreArray) Get() interface{} { switch dst.Status { case Present: @@ -87,23 +133,26 @@ func (dst HstoreArray) Get() interface{} { func (src *HstoreArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]map[string]string: - *v = make([]map[string]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -111,6 +160,49 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *HstoreArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from HstoreArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = HstoreArray{Status: Null} diff --git a/hstore_array_test.go b/hstore_array_test.go index 32b91840..fac66b4a 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -131,7 +131,7 @@ func TestHstoreArrayTranscode(t *testing.T) { func TestHstoreArraySet(t *testing.T) { successfulTests := []struct { - src []map[string]string + src interface{} result pgtype.HstoreArray }{ { @@ -147,6 +147,118 @@ func TestHstoreArraySet(t *testing.T) { Status: pgtype.Present, }, }, + { + src: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + }, + { + src: [][][][]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + }, + { + src: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + }, + { + src: [2][1][1][3]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + }, } for i, tt := range successfulTests { @@ -163,12 +275,16 @@ func TestHstoreArraySet(t *testing.T) { } func TestHstoreArrayAssignTo(t *testing.T) { - var m []map[string]string + var hstoreSlice []map[string]string + var hstoreSliceDim2 [][]map[string]string + var hstoreSliceDim4 [][][][]map[string]string + var hstoreArrayDim2 [2][1]map[string]string + var hstoreArrayDim4 [2][1][1][3]map[string]string simpleTests := []struct { src pgtype.HstoreArray - dst *[]map[string]string - expected []map[string]string + dst interface{} + expected interface{} }{ { src: pgtype.HstoreArray{ @@ -181,9 +297,127 @@ func TestHstoreArrayAssignTo(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, - dst: &m, + dst: &hstoreSlice, expected: []map[string]string{{"foo": "bar"}}}, - {src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &m, expected: (([]map[string]string)(nil))}, + { + src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &hstoreSlice, expected: (([]map[string]string)(nil)), + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &hstoreSliceDim2, + expected: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + dst: &hstoreSliceDim4, + expected: [][][][]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &hstoreArrayDim2, + expected: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present, + }, + dst: &hstoreArrayDim4, + expected: [2][1][1][3]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + }, } for i, tt := range simpleTests { @@ -192,8 +426,8 @@ func TestHstoreArrayAssignTo(t *testing.T) { t.Errorf("%d: %v", i, err) } - if !reflect.DeepEqual(*tt.dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } } } diff --git a/inet_array.go b/inet_array.go index ca4c1a02..91f5d6e8 100644 --- a/inet_array.go +++ b/inet_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "net" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,87 +31,94 @@ func (dst *InetArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = InetArray{Status: Null} + return nil + } - case []*net.IPNet: - if value == nil { - *dst = InetArray{Status: Null} - } else if len(value) == 0 { - *dst = InetArray{Status: Present} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []net.IP: - if value == nil { - *dst = InetArray{Status: Null} - } else if len(value) == 0 { - *dst = InetArray{Status: Present} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*net.IP: - if value == nil { - *dst = InetArray{Status: Null} - } else if len(value) == 0 { - *dst = InetArray{Status: Present} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Inet: - if value == nil { - *dst = InetArray{Status: Null} - } else if len(value) == 0 { - *dst = InetArray{Status: Present} - } else { - *dst = InetArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for InetArray", src) + } + if elementsLength == 0 { + *dst = InetArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to InetArray", value) + return errors.Errorf("cannot convert %v to InetArray", src) + } + + *dst = InetArray{ + Elements: make([]Inet, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Inet, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to InetArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *InetArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to InetArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in InetArray", err) + } + index++ + + return index, nil +} + func (dst InetArray) Get() interface{} { switch dst.Status { case Present: @@ -126,41 +133,26 @@ func (dst InetArray) Get() interface{} { func (src *InetArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]*net.IPNet: - *v = make([]*net.IPNet, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]net.IP: - *v = make([]net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.IP: - *v = make([]*net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -168,6 +160,49 @@ func (src *InetArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *InetArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from InetArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = InetArray{Status: Null} diff --git a/inet_array_test.go b/inet_array_test.go index 6737aac0..d78b91c0 100644 --- a/inet_array_test.go +++ b/inet_array_test.go @@ -80,6 +80,74 @@ func TestInetArraySet(t *testing.T) { source: (([]net.IP)(nil)), result: pgtype.InetArray{Status: pgtype.Null}, }, + { + source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -98,6 +166,10 @@ func TestInetArraySet(t *testing.T) { func TestInetArrayAssignTo(t *testing.T) { var ipnetSlice []*net.IPNet var ipSlice []net.IP + var ipSliceDim2 [][]net.IP + var ipnetSliceDim4 [][][][]*net.IPNet + var ipArrayDim2 [2][1]net.IP + var ipnetArrayDim4 [2][1][1][3]*net.IPNet simpleTests := []struct { src pgtype.InetArray @@ -150,6 +222,78 @@ func TestInetArrayAssignTo(t *testing.T) { dst: &ipSlice, expected: (([]net.IP)(nil)), }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipSliceDim2, + expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetSliceDim4, + expected: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &ipArrayDim2, + expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Status: pgtype.Present}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &ipnetArrayDim4, + expected: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, } for i, tt := range simpleTests { diff --git a/int2_array.go b/int2_array.go index ad2bd094..06febf01 100644 --- a/int2_array.go +++ b/int2_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,334 +31,94 @@ func (dst *Int2Array) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = Int2Array{Status: Null} + return nil + } - case []int16: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int16: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint16: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint16: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int32: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int32: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint32: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint32: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int64: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int64: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint64: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint64: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - elements := make([]Int2, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int2Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Int2: - if value == nil { - *dst = Int2Array{Status: Null} - } else if len(value) == 0 { - *dst = Int2Array{Status: Present} - } else { - *dst = Int2Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Int2Array", src) + } + if elementsLength == 0 { + *dst = Int2Array{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int2Array", value) + return errors.Errorf("cannot convert %v to Int2Array", src) + } + + *dst = Int2Array{ + Elements: make([]Int2, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int2, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Int2Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *Int2Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to Int2Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in Int2Array", err) + } + index++ + + return index, nil +} + func (dst Int2Array) Get() interface{} { switch dst.Status { case Present: @@ -372,158 +133,26 @@ func (dst Int2Array) Get() interface{} { func (src *Int2Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -531,6 +160,49 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from Int2Array") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2Array{Status: Null} diff --git a/int2_array_test.go b/int2_array_test.go index 22f71745..dfe84c19 100644 --- a/int2_array_test.go +++ b/int2_array_test.go @@ -110,6 +110,54 @@ func TestInt2ArraySet(t *testing.T) { source: (([]int16)(nil)), result: pgtype.Int2Array{Status: pgtype.Null}, }, + { + source: [][]int16{{1}, {2}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]int16{{1}, {2}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -129,6 +177,10 @@ func TestInt2ArrayAssignTo(t *testing.T) { var int16Slice []int16 var uint16Slice []uint16 var namedInt16Slice _int16Slice + var int16SliceDim2 [][]int16 + var int16SliceDim4 [][][][]int16 + var int16ArrayDim2 [2][1]int16 + var int16ArrayDim4 [2][1][1][3]int16 simpleTests := []struct { src pgtype.Int2Array @@ -167,6 +219,58 @@ func TestInt2ArrayAssignTo(t *testing.T) { dst: &int16Slice, expected: (([]int16)(nil)), }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]int16{{1}, {2}}, + dst: &int16SliceDim2, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int16SliceDim4, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]int16{{1}, {2}}, + dst: &int16ArrayDim2, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int16ArrayDim4, + }, } for i, tt := range simpleTests { @@ -200,6 +304,27 @@ func TestInt2ArrayAssignTo(t *testing.T) { }, dst: &uint16Slice, }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int16ArrayDim2, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int16Slice, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &int16ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/int4_array.go b/int4_array.go index 15565f64..189bd238 100644 --- a/int4_array.go +++ b/int4_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,334 +31,94 @@ func (dst *Int4Array) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = Int4Array{Status: Null} + return nil + } - case []int16: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int16: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint16: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint16: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int32: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int32: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint32: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint32: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int64: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int64: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint64: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint64: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - elements := make([]Int4, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int4Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Int4: - if value == nil { - *dst = Int4Array{Status: Null} - } else if len(value) == 0 { - *dst = Int4Array{Status: Present} - } else { - *dst = Int4Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Int4Array", src) + } + if elementsLength == 0 { + *dst = Int4Array{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int4Array", value) + return errors.Errorf("cannot convert %v to Int4Array", src) + } + + *dst = Int4Array{ + Elements: make([]Int4, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int4, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Int4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *Int4Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to Int4Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in Int4Array", err) + } + index++ + + return index, nil +} + func (dst Int4Array) Get() interface{} { switch dst.Status { case Present: @@ -372,158 +133,26 @@ func (dst Int4Array) Get() interface{} { func (src *Int4Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -531,6 +160,49 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *Int4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from Int4Array") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4Array{Status: Null} diff --git a/int4_array_test.go b/int4_array_test.go index c839c1c9..35b791d3 100644 --- a/int4_array_test.go +++ b/int4_array_test.go @@ -116,6 +116,54 @@ func TestInt4ArraySet(t *testing.T) { source: (([]int32)(nil)), result: pgtype.Int4Array{Status: pgtype.Null}, }, + { + source: [][]int32{{1}, {2}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]int32{{1}, {2}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -143,6 +191,10 @@ func TestInt4ArrayAssignTo(t *testing.T) { var int32Slice []int32 var uint32Slice []uint32 var namedInt32Slice _int32Slice + var int32SliceDim2 [][]int32 + var int32SliceDim4 [][][][]int32 + var int32ArrayDim2 [2][1]int32 + var int32ArrayDim4 [2][1][1][3]int32 simpleTests := []struct { src pgtype.Int4Array @@ -181,6 +233,58 @@ func TestInt4ArrayAssignTo(t *testing.T) { dst: &int32Slice, expected: (([]int32)(nil)), }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]int32{{1}, {2}}, + dst: &int32SliceDim2, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int32SliceDim4, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]int32{{1}, {2}}, + dst: &int32ArrayDim2, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int32ArrayDim4, + }, } for i, tt := range simpleTests { @@ -214,6 +318,27 @@ func TestInt4ArrayAssignTo(t *testing.T) { }, dst: &uint32Slice, }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int32ArrayDim2, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int32Slice, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &int32ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/int8_array.go b/int8_array.go index e8e8823a..edb232cb 100644 --- a/int8_array.go +++ b/int8_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,334 +31,94 @@ func (dst *Int8Array) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = Int8Array{Status: Null} + return nil + } - case []int16: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int16: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint16: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint16: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int32: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int32: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint32: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint32: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int64: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int64: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint64: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint64: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - elements := make([]Int8, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = Int8Array{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Int8: - if value == nil { - *dst = Int8Array{Status: Null} - } else if len(value) == 0 { - *dst = Int8Array{Status: Present} - } else { - *dst = Int8Array{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for Int8Array", src) + } + if elementsLength == 0 { + *dst = Int8Array{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to Int8Array", value) + return errors.Errorf("cannot convert %v to Int8Array", src) + } + + *dst = Int8Array{ + Elements: make([]Int8, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int8, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to Int8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *Int8Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to Int8Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in Int8Array", err) + } + index++ + + return index, nil +} + func (dst Int8Array) Get() interface{} { switch dst.Status { case Present: @@ -372,158 +133,26 @@ func (dst Int8Array) Get() interface{} { func (src *Int8Array) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]int16: - *v = make([]int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int16: - *v = make([]*int16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint16: - *v = make([]uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint16: - *v = make([]*uint16, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int32: - *v = make([]int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int32: - *v = make([]*int32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint32: - *v = make([]uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint32: - *v = make([]*uint32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int: - *v = make([]int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int: - *v = make([]*int, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint: - *v = make([]uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint: - *v = make([]*uint, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -531,6 +160,49 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *Int8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from Int8Array") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8Array{Status: Null} diff --git a/int8_array_test.go b/int8_array_test.go index e9e7acfb..d65b875a 100644 --- a/int8_array_test.go +++ b/int8_array_test.go @@ -117,6 +117,54 @@ func TestInt8ArraySet(t *testing.T) { source: (([]int64)(nil)), result: pgtype.Int8Array{Status: pgtype.Null}, }, + { + source: [][]int64{{1}, {2}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]int64{{1}, {2}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -136,6 +184,10 @@ func TestInt8ArrayAssignTo(t *testing.T) { var int64Slice []int64 var uint64Slice []uint64 var namedInt64Slice _int64Slice + var int64SliceDim2 [][]int64 + var int64SliceDim4 [][][][]int64 + var int64ArrayDim2 [2][1]int64 + var int64ArrayDim4 [2][1][1][3]int64 simpleTests := []struct { src pgtype.Int8Array @@ -174,6 +226,58 @@ func TestInt8ArrayAssignTo(t *testing.T) { dst: &int64Slice, expected: (([]int64)(nil)), }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [][]int64{{1}, {2}}, + dst: &int64SliceDim2, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int64SliceDim4, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + expected: [2][1]int64{{1}, {2}}, + dst: &int64ArrayDim2, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Status: pgtype.Present}, + {Int: 2, Status: pgtype.Present}, + {Int: 3, Status: pgtype.Present}, + {Int: 4, Status: pgtype.Present}, + {Int: 5, Status: pgtype.Present}, + {Int: 6, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + expected: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int64ArrayDim4, + }, } for i, tt := range simpleTests { @@ -207,6 +311,27 @@ func TestInt8ArrayAssignTo(t *testing.T) { }, dst: &uint64Slice, }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int64ArrayDim2, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &int64Slice, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &int64ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/jsonb_array.go b/jsonb_array.go index daebfa7b..c5a40a1d 100644 --- a/jsonb_array.go +++ b/jsonb_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,49 +31,94 @@ func (dst *JSONBArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = JSONBArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = JSONBArray{Status: Null} - } else if len(value) == 0 { - *dst = JSONBArray{Status: Present} - } else { - elements := make([]Text, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = JSONBArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Text: - if value == nil { - *dst = JSONBArray{Status: Null} - } else if len(value) == 0 { - *dst = JSONBArray{Status: Present} - } else { - *dst = JSONBArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for JSONBArray", src) + } + if elementsLength == 0 { + *dst = JSONBArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to JSONBArray", value) + return errors.Errorf("cannot convert %v to JSONBArray", src) + } + + *dst = JSONBArray{ + Elements: make([]Text, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Text, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to JSONBArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *JSONBArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to JSONBArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in JSONBArray", err) + } + index++ + + return index, nil +} + func (dst JSONBArray) Get() interface{} { switch dst.Status { case Present: @@ -87,23 +133,26 @@ func (dst JSONBArray) Get() interface{} { func (src *JSONBArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -111,6 +160,49 @@ func (src *JSONBArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *JSONBArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from JSONBArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *JSONBArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = JSONBArray{Status: Null} diff --git a/macaddr_array.go b/macaddr_array.go index 616d6f85..398db1fe 100644 --- a/macaddr_array.go +++ b/macaddr_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "net" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,68 +31,94 @@ func (dst *MacaddrArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = MacaddrArray{Status: Null} + return nil + } - case []net.HardwareAddr: - if value == nil { - *dst = MacaddrArray{Status: Null} - } else if len(value) == 0 { - *dst = MacaddrArray{Status: Present} - } else { - elements := make([]Macaddr, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = MacaddrArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*net.HardwareAddr: - if value == nil { - *dst = MacaddrArray{Status: Null} - } else if len(value) == 0 { - *dst = MacaddrArray{Status: Present} - } else { - elements := make([]Macaddr, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = MacaddrArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Macaddr: - if value == nil { - *dst = MacaddrArray{Status: Null} - } else if len(value) == 0 { - *dst = MacaddrArray{Status: Present} - } else { - *dst = MacaddrArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for MacaddrArray", src) + } + if elementsLength == 0 { + *dst = MacaddrArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to MacaddrArray", value) + return errors.Errorf("cannot convert %v to MacaddrArray", src) + } + + *dst = MacaddrArray{ + Elements: make([]Macaddr, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Macaddr, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to MacaddrArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *MacaddrArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to MacaddrArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in MacaddrArray", err) + } + index++ + + return index, nil +} + func (dst MacaddrArray) Get() interface{} { switch dst.Status { case Present: @@ -107,32 +133,26 @@ func (dst MacaddrArray) Get() interface{} { func (src *MacaddrArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]net.HardwareAddr: - *v = make([]net.HardwareAddr, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.HardwareAddr: - *v = make([]*net.HardwareAddr, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -140,6 +160,49 @@ func (src *MacaddrArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *MacaddrArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from MacaddrArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *MacaddrArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = MacaddrArray{Status: Null} diff --git a/macaddr_array_test.go b/macaddr_array_test.go index d2b0a73b..647db8cf 100644 --- a/macaddr_array_test.go +++ b/macaddr_array_test.go @@ -44,6 +44,78 @@ func TestMacaddrArraySet(t *testing.T) { source: (([]net.HardwareAddr)(nil)), result: pgtype.MacaddrArray{Status: pgtype.Null}, }, + { + source: [][]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -61,6 +133,10 @@ func TestMacaddrArraySet(t *testing.T) { func TestMacaddrArrayAssignTo(t *testing.T) { var macaddrSlice []net.HardwareAddr + var macaddrSliceDim2 [][]net.HardwareAddr + var macaddrSliceDim4 [][][][]net.HardwareAddr + var macaddrArrayDim2 [2][1]net.HardwareAddr + var macaddrArrayDim4 [2][1][1][3]net.HardwareAddr simpleTests := []struct { src pgtype.MacaddrArray @@ -90,6 +166,82 @@ func TestMacaddrArrayAssignTo(t *testing.T) { dst: &macaddrSlice, expected: (([]net.HardwareAddr)(nil)), }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &macaddrSliceDim2, + expected: [][]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &macaddrSliceDim4, + expected: [][][][]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &macaddrArrayDim2, + expected: [2][1]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Status: pgtype.Present}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &macaddrArrayDim4, + expected: [2][1][1][3]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + }, } for i, tt := range simpleTests { diff --git a/numeric_array.go b/numeric_array.go index e086ca7a..dec81535 100644 --- a/numeric_array.go +++ b/numeric_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,182 +31,94 @@ func (dst *NumericArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = NumericArray{Status: Null} + return nil + } - case []float32: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*float32: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []float64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*float64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []int64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*int64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []uint64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*uint64: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - elements := make([]Numeric, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = NumericArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Numeric: - if value == nil { - *dst = NumericArray{Status: Null} - } else if len(value) == 0 { - *dst = NumericArray{Status: Present} - } else { - *dst = NumericArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for NumericArray", src) + } + if elementsLength == 0 { + *dst = NumericArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to NumericArray", value) + return errors.Errorf("cannot convert %v to NumericArray", src) + } + + *dst = NumericArray{ + Elements: make([]Numeric, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Numeric, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to NumericArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *NumericArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to NumericArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in NumericArray", err) + } + index++ + + return index, nil +} + func (dst NumericArray) Get() interface{} { switch dst.Status { case Present: @@ -220,86 +133,26 @@ func (dst NumericArray) Get() interface{} { func (src *NumericArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]float32: - *v = make([]float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float32: - *v = make([]*float32, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]float64: - *v = make([]float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*float64: - *v = make([]*float64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]int64: - *v = make([]int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*int64: - *v = make([]*int64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]uint64: - *v = make([]uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*uint64: - *v = make([]*uint64, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -307,6 +160,49 @@ func (src *NumericArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *NumericArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from NumericArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = NumericArray{Status: Null} diff --git a/numeric_array_test.go b/numeric_array_test.go index eafd31be..29300bf0 100644 --- a/numeric_array_test.go +++ b/numeric_array_test.go @@ -91,6 +91,54 @@ func TestNumericArraySet(t *testing.T) { source: (([]float32)(nil)), result: pgtype.NumericArray{Status: pgtype.Null}, }, + { + source: [][]float32{{1}, {2}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]float32{{1}, {2}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -109,6 +157,10 @@ func TestNumericArraySet(t *testing.T) { func TestNumericArrayAssignTo(t *testing.T) { var float32Slice []float32 var float64Slice []float64 + var float32SliceDim2 [][]float32 + var float32SliceDim4 [][][][]float32 + var float32ArrayDim2 [2][1]float32 + var float32ArrayDim4 [2][1][1][3]float32 simpleTests := []struct { src pgtype.NumericArray @@ -138,6 +190,58 @@ func TestNumericArrayAssignTo(t *testing.T) { dst: &float32Slice, expected: (([]float32)(nil)), }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32SliceDim2, + expected: [][]float32{{1}, {2}}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &float32SliceDim4, + expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32ArrayDim2, + expected: [2][1]float32{{1}, {2}}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Status: pgtype.Present}, + {Int: big.NewInt(2), Status: pgtype.Present}, + {Int: big.NewInt(3), Status: pgtype.Present}, + {Int: big.NewInt(4), Status: pgtype.Present}, + {Int: big.NewInt(5), Status: pgtype.Present}, + {Int: big.NewInt(6), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &float32ArrayDim4, + expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + }, } for i, tt := range simpleTests { @@ -163,6 +267,27 @@ func TestNumericArrayAssignTo(t *testing.T) { }, dst: &float32Slice, }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32ArrayDim2, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &float32Slice, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}, {Int: big.NewInt(2), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &float32ArrayDim4, + }, } for i, tt := range errorTests { diff --git a/text_array.go b/text_array.go index d1583557..31ed04ac 100644 --- a/text_array.go +++ b/text_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *TextArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = TextArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = TextArray{Status: Null} - } else if len(value) == 0 { - *dst = TextArray{Status: Present} - } else { - elements := make([]Text, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TextArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = TextArray{Status: Null} - } else if len(value) == 0 { - *dst = TextArray{Status: Present} - } else { - elements := make([]Text, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TextArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Text: - if value == nil { - *dst = TextArray{Status: Null} - } else if len(value) == 0 { - *dst = TextArray{Status: Present} - } else { - *dst = TextArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TextArray", src) + } + if elementsLength == 0 { + *dst = TextArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TextArray", value) + return errors.Errorf("cannot convert %v to TextArray", src) + } + + *dst = TextArray{ + Elements: make([]Text, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Text, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TextArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *TextArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to TextArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in TextArray", err) + } + index++ + + return index, nil +} + func (dst TextArray) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst TextArray) Get() interface{} { func (src *TextArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *TextArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *TextArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from TextArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TextArray{Status: Null} diff --git a/text_array_test.go b/text_array_test.go index a29ce617..125d6034 100644 --- a/text_array_test.go +++ b/text_array_test.go @@ -68,6 +68,54 @@ func TestTextArraySet(t *testing.T) { source: (([]string)(nil)), result: pgtype.TextArray{Status: pgtype.Null}, }, + { + source: [][]string{{"foo"}, {"bar"}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"foo"}, {"bar"}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -87,6 +135,10 @@ func TestTextArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string simpleTests := []struct { src pgtype.TextArray @@ -116,6 +168,58 @@ func TestTextArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, } for i, tt := range simpleTests { @@ -141,6 +245,27 @@ func TestTextArrayAssignTo(t *testing.T) { }, dst: &stringSlice, }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, } for i, tt := range errorTests { diff --git a/timestamp_array.go b/timestamp_array.go index 3b2c3141..355b29c5 100644 --- a/timestamp_array.go +++ b/timestamp_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "time" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,68 +31,94 @@ func (dst *TimestampArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = TimestampArray{Status: Null} + return nil + } - case []time.Time: - if value == nil { - *dst = TimestampArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestampArray{Status: Present} - } else { - elements := make([]Timestamp, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestampArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*time.Time: - if value == nil { - *dst = TimestampArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestampArray{Status: Present} - } else { - elements := make([]Timestamp, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestampArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Timestamp: - if value == nil { - *dst = TimestampArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestampArray{Status: Present} - } else { - *dst = TimestampArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TimestampArray", src) + } + if elementsLength == 0 { + *dst = TimestampArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TimestampArray", value) + return errors.Errorf("cannot convert %v to TimestampArray", src) + } + + *dst = TimestampArray{ + Elements: make([]Timestamp, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Timestamp, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TimestampArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *TimestampArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to TimestampArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in TimestampArray", err) + } + index++ + + return index, nil +} + func (dst TimestampArray) Get() interface{} { switch dst.Status { case Present: @@ -107,32 +133,26 @@ func (dst TimestampArray) Get() interface{} { func (src *TimestampArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*time.Time: - *v = make([]*time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -140,6 +160,49 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *TimestampArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from TimestampArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestampArray{Status: Null} diff --git a/timestamp_array_test.go b/timestamp_array_test.go index d7632fa3..c6f32d20 100644 --- a/timestamp_array_test.go +++ b/timestamp_array_test.go @@ -85,6 +85,42 @@ func TestTimestampArraySet(t *testing.T) { source: (([]time.Time)(nil)), result: pgtype.TimestampArray{Status: pgtype.Null}, }, + { + source: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -102,6 +138,10 @@ func TestTimestampArraySet(t *testing.T) { func TestTimestampArrayAssignTo(t *testing.T) { var timeSlice []time.Time + var timeSliceDim2 [][]time.Time + var timeSliceDim4 [][][][]time.Time + var timeArrayDim2 [2][1]time.Time + var timeArrayDim4 [2][1][1][3]time.Time simpleTests := []struct { src pgtype.TimestampArray @@ -122,6 +162,82 @@ func TestTimestampArrayAssignTo(t *testing.T) { dst: &timeSlice, expected: (([]time.Time)(nil)), }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeSliceDim2, + expected: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeSliceDim4, + expected: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + expected: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + expected: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, } for i, tt := range simpleTests { @@ -147,6 +263,33 @@ func TestTimestampArrayAssignTo(t *testing.T) { }, dst: &timeSlice, }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeSlice, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + }, } for i, tt := range errorTests { diff --git a/timestamptz_array.go b/timestamptz_array.go index 3328ec05..94a791b6 100644 --- a/timestamptz_array.go +++ b/timestamptz_array.go @@ -5,7 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "time" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -31,68 +31,94 @@ func (dst *TimestamptzArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = TimestamptzArray{Status: Null} + return nil + } - case []time.Time: - if value == nil { - *dst = TimestamptzArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestamptzArray{Status: Present} - } else { - elements := make([]Timestamptz, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestamptzArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*time.Time: - if value == nil { - *dst = TimestamptzArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestamptzArray{Status: Present} - } else { - elements := make([]Timestamptz, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = TimestamptzArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Timestamptz: - if value == nil { - *dst = TimestamptzArray{Status: Null} - } else if len(value) == 0 { - *dst = TimestamptzArray{Status: Present} - } else { - *dst = TimestamptzArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TimestamptzArray", src) + } + if elementsLength == 0 { + *dst = TimestamptzArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TimestamptzArray", value) + return errors.Errorf("cannot convert %v to TimestamptzArray", src) + } + + *dst = TimestamptzArray{ + Elements: make([]Timestamptz, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Timestamptz, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TimestamptzArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *TimestamptzArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to TimestamptzArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in TimestamptzArray", err) + } + index++ + + return index, nil +} + func (dst TimestamptzArray) Get() interface{} { switch dst.Status { case Present: @@ -107,32 +133,26 @@ func (dst TimestamptzArray) Get() interface{} { func (src *TimestamptzArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]time.Time: - *v = make([]time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*time.Time: - *v = make([]*time.Time, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -140,6 +160,49 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *TimestamptzArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from TimestamptzArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestamptzArray{Status: Null} diff --git a/timestamptz_array_test.go b/timestamptz_array_test.go index 8a4cfd1d..f4e80413 100644 --- a/timestamptz_array_test.go +++ b/timestamptz_array_test.go @@ -85,6 +85,78 @@ func TestTimestamptzArraySet(t *testing.T) { source: (([]time.Time)(nil)), result: pgtype.TimestamptzArray{Status: pgtype.Null}, }, + { + source: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -102,6 +174,10 @@ func TestTimestamptzArraySet(t *testing.T) { func TestTimestamptzArrayAssignTo(t *testing.T) { var timeSlice []time.Time + var timeSliceDim2 [][]time.Time + var timeSliceDim4 [][][][]time.Time + var timeArrayDim2 [2][1]time.Time + var timeArrayDim4 [2][1][1][3]time.Time simpleTests := []struct { src pgtype.TimestamptzArray @@ -122,6 +198,82 @@ func TestTimestamptzArrayAssignTo(t *testing.T) { dst: &timeSlice, expected: (([]time.Time)(nil)), }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeSliceDim2, + expected: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeSliceDim4, + expected: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + expected: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + expected: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, } for i, tt := range simpleTests { @@ -147,6 +299,33 @@ func TestTimestamptzArrayAssignTo(t *testing.T) { }, dst: &timeSlice, }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeArrayDim2, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &timeSlice, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &timeArrayDim4, + }, } for i, tt := range errorTests { diff --git a/tstzrange_array.go b/tstzrange_array.go index c19a9bfa..f5043c65 100644 --- a/tstzrange_array.go +++ b/tstzrange_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,30 +31,94 @@ func (dst *TstzrangeArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = TstzrangeArray{Status: Null} + return nil + } - case []Tstzrange: - if value == nil { - *dst = TstzrangeArray{Status: Null} - } else if len(value) == 0 { - *dst = TstzrangeArray{Status: Present} - } else { - *dst = TstzrangeArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for TstzrangeArray", src) + } + if elementsLength == 0 { + *dst = TstzrangeArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to TstzrangeArray", value) + return errors.Errorf("cannot convert %v to TstzrangeArray", src) + } + + *dst = TstzrangeArray{ + Elements: make([]Tstzrange, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Tstzrange, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to TstzrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *TstzrangeArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to TstzrangeArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in TstzrangeArray", err) + } + index++ + + return index, nil +} + func (dst TstzrangeArray) Get() interface{} { switch dst.Status { case Present: @@ -68,23 +133,26 @@ func (dst TstzrangeArray) Get() interface{} { func (src *TstzrangeArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]Tstzrange: - *v = make([]Tstzrange, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -92,6 +160,49 @@ func (src *TstzrangeArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *TstzrangeArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from TstzrangeArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *TstzrangeArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TstzrangeArray{Status: Null} diff --git a/typed_array.go.erb b/typed_array.go.erb index a3deea5b..fb964ec8 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -30,51 +30,94 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { } } - switch value := src.(type) { - <% go_array_types.split(",").each do |t| %> - <% if t != "[]#{pgtype_element_type}" %> - case <%= t %>: - if value == nil { - *dst = <%= pgtype_array_type %>{Status: Null} - } else if len(value) == 0 { - *dst = <%= pgtype_array_type %>{Status: Present} - } else { - elements := make([]<%= pgtype_element_type %>, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = <%= pgtype_array_type %>{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - <% end %> - <% end %> - case []<%= pgtype_element_type %>: - if value == nil { - *dst = <%= pgtype_array_type %>{Status: Null} - } else if len(value) == 0 { - *dst = <%= pgtype_array_type %>{Status: Present} - } else { - *dst = <%= pgtype_array_type %>{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status : Present, - } - } - default: + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = <%= pgtype_array_type %>{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for <%= pgtype_array_type %>", src) + } + if elementsLength == 0 { + *dst = <%= pgtype_array_type %>{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>", value) + return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>", src) + } + + *dst = <%= pgtype_array_type %> { + Elements: make([]<%= pgtype_element_type %>, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]<%= pgtype_element_type %>, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to <%= pgtype_array_type %>, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *<%= pgtype_array_type %>) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to <%= pgtype_array_type %>") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in <%= pgtype_array_type %>", err) + } + index++ + + return index, nil +} + func (dst <%= pgtype_array_type %>) Get() interface{} { switch dst.Status { case Present: @@ -89,23 +132,26 @@ func (dst <%= pgtype_array_type %>) Get() interface{} { func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - <% go_array_types.split(",").each do |t| %> - case *<%= t %>: - *v = make(<%= t %>, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - <% end %> - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -113,6 +159,49 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from <%= pgtype_array_type %>") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 607d3bc3..8c594944 100755 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -1,27 +1,27 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time,[]*time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time,[]*time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go -erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time,[]*time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32,[]*float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64,[]*float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go -erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr go_array_types=[]net.HardwareAddr,[]*net.HardwareAddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go -erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string,[]*string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go -erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string,[]*string element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go -erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string,[]*string element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go -erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string,[]*string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go -erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go -erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]*float32,[]float64,[]*float64,[]int64,[]*int64,[]uint64,[]*uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go -erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string,[]*string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go -erb pgtype_array_type=JSONBArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go +erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go +erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go +erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go +erb pgtype_array_type=NumericArray pgtype_element_type=Numeric element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go +erb pgtype_array_type=UUIDArray pgtype_element_type=UUID element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go +erb pgtype_array_type=JSONBArray pgtype_element_type=Text element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go # While the binary format is theoretically possible it is only practical to use the text format. -erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string text_null=NULL binary_format=false typed_array.go.erb > enum_array.go +erb pgtype_array_type=EnumArray pgtype_element_type=GenericText text_null=NULL binary_format=false typed_array.go.erb > enum_array.go goimports -w *_array.go diff --git a/uuid_array.go b/uuid_array.go index 06d2d576..e2c86cf8 100644 --- a/uuid_array.go +++ b/uuid_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,106 +31,94 @@ func (dst *UUIDArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = UUIDArray{Status: Null} + return nil + } - case [][16]byte: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case [][]byte: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []string: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - elements := make([]UUID, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = UUIDArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []UUID: - if value == nil { - *dst = UUIDArray{Status: Null} - } else if len(value) == 0 { - *dst = UUIDArray{Status: Present} - } else { - *dst = UUIDArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for UUIDArray", src) + } + if elementsLength == 0 { + *dst = UUIDArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to UUIDArray", value) + return errors.Errorf("cannot convert %v to UUIDArray", src) + } + + *dst = UUIDArray{ + Elements: make([]UUID, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]UUID, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to UUIDArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *UUIDArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to UUIDArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in UUIDArray", err) + } + index++ + + return index, nil +} + func (dst UUIDArray) Get() interface{} { switch dst.Status { case Present: @@ -144,50 +133,26 @@ func (dst UUIDArray) Get() interface{} { func (src *UUIDArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[][16]byte: - *v = make([][16]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[][]byte: - *v = make([][]byte, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -195,6 +160,49 @@ func (src *UUIDArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *UUIDArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from UUIDArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *UUIDArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = UUIDArray{Status: Null} diff --git a/uuid_array_test.go b/uuid_array_test.go index d5446920..cdb212bb 100644 --- a/uuid_array_test.go +++ b/uuid_array_test.go @@ -123,6 +123,78 @@ func TestUUIDArraySet(t *testing.T) { source: ([]string)(nil), result: pgtype.UUIDArray{Status: pgtype.Null}, }, + { + source: [][][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -142,6 +214,10 @@ func TestUUIDArrayAssignTo(t *testing.T) { var byteArraySlice [][16]byte var byteSliceSlice [][]byte var stringSlice []string + var byteArraySliceDim2 [][][16]byte + var stringSliceDim4 [][][][]string + var byteArrayDim2 [2][1][16]byte + var stringArrayDim4 [2][1][1][3]string simpleTests := []struct { src pgtype.UUIDArray @@ -190,6 +266,82 @@ func TestUUIDArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: ([]string)(nil), }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteArraySliceDim2, + expected: [][][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &byteArrayDim2, + expected: [2][1][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + }, } for i, tt := range simpleTests { diff --git a/varchar_array.go b/varchar_array.go index 32ca5941..ec378ed7 100644 --- a/varchar_array.go +++ b/varchar_array.go @@ -5,6 +5,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "reflect" "github.com/jackc/pgio" errors "golang.org/x/xerrors" @@ -30,68 +31,94 @@ func (dst *VarcharArray) Set(src interface{}) error { } } - switch value := src.(type) { + value := reflect.ValueOf(src) + if !value.IsValid() || value.IsZero() { + *dst = VarcharArray{Status: Null} + return nil + } - case []string: - if value == nil { - *dst = VarcharArray{Status: Null} - } else if len(value) == 0 { - *dst = VarcharArray{Status: Present} - } else { - elements := make([]Varchar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = VarcharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []*string: - if value == nil { - *dst = VarcharArray{Status: Null} - } else if len(value) == 0 { - *dst = VarcharArray{Status: Present} - } else { - elements := make([]Varchar, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = VarcharArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Status: Present, - } - } - - case []Varchar: - if value == nil { - *dst = VarcharArray{Status: Null} - } else if len(value) == 0 { - *dst = VarcharArray{Status: Present} - } else { - *dst = VarcharArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Status: Present, - } - } - default: + dimensions, elementsLength, ok := findDimensionsFromValue(reflect.ValueOf(src), nil, 0) + if !ok { + return errors.Errorf("cannot find dimensions of %v for VarcharArray", src) + } + if elementsLength == 0 { + *dst = VarcharArray{Status: Present} + return nil + } + if len(dimensions) == 0 { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return errors.Errorf("cannot convert %v to VarcharArray", value) + return errors.Errorf("cannot convert %v to VarcharArray", src) + } + + *dst = VarcharArray{ + Elements: make([]Varchar, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Varchar, elementsLength) + elementCount, err = dst.setRecursive(reflect.ValueOf(src), 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return errors.Errorf("cannot convert %v to VarcharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) } return nil } +func (dst *VarcharArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + if int32(value.Len()) != dst.Dimensions[dimension].Length { + return 0, errors.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < value.Len(); i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, errors.Errorf("cannot convert all values to VarcharArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, errors.Errorf("%v in VarcharArray", err) + } + index++ + + return index, nil +} + func (dst VarcharArray) Get() interface{} { switch dst.Status { case Present: @@ -106,32 +133,26 @@ func (dst VarcharArray) Get() interface{} { func (src *VarcharArray) AssignTo(dst interface{}) error { switch src.Status { case Present: - switch v := dst.(type) { - - case *[]string: - *v = make([]string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*string: - *v = make([]*string, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - default: + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if !value.CanSet() { if nextDst, retry := GetAssignToDstType(dst); retry { return src.AssignTo(nextDst) } return errors.Errorf("unable to assign to %T", dst) } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return errors.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil case Null: return NullAssignTo(dst) } @@ -139,6 +160,49 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (src *VarcharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + if value.Type().Len() != length { + return 0, errors.Errorf("expected size %d array, but %s has size %d array", length, value.Type(), value.Type().Len()) + } + value.Set(reflect.New(value.Type()).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, errors.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() || !value.Addr().CanInterface() { + return 0, errors.Errorf("cannot assign all values from VarcharArray") + } + err := src.Elements[index].AssignTo(value.Addr().Interface()) + if err != nil { + return 0, err + } + index++ + return index, nil +} + func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = VarcharArray{Status: Null} diff --git a/varchar_array_test.go b/varchar_array_test.go index 9ad80862..3b0e65ed 100644 --- a/varchar_array_test.go +++ b/varchar_array_test.go @@ -68,6 +68,54 @@ func TestVarcharArraySet(t *testing.T) { source: (([]string)(nil)), result: pgtype.VarcharArray{Status: pgtype.Null}, }, + { + source: [][]string{{"foo"}, {"bar"}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, + { + source: [2][1]string{{"foo"}, {"bar"}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + }, } for i, tt := range successfulTests { @@ -87,6 +135,10 @@ func TestVarcharArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string simpleTests := []struct { src pgtype.VarcharArray @@ -116,6 +168,58 @@ func TestVarcharArrayAssignTo(t *testing.T) { dst: &stringSlice, expected: (([]string)(nil)), }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringSliceDim2, + expected: [][]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringSliceDim4, + expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + expected: [2][1]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Status: pgtype.Present}, + {String: "bar", Status: pgtype.Present}, + {String: "baz", Status: pgtype.Present}, + {String: "wibble", Status: pgtype.Present}, + {String: "wobble", Status: pgtype.Present}, + {String: "wubble", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, } for i, tt := range simpleTests { @@ -141,6 +245,27 @@ func TestVarcharArrayAssignTo(t *testing.T) { }, dst: &stringSlice, }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringArrayDim2, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Status: pgtype.Present}, + dst: &stringSlice, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}, {String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + dst: &stringArrayDim4, + }, } for i, tt := range errorTests {