diff --git a/Rakefile b/Rakefile index f3a61a09..3fe26cb5 100644 --- a/Rakefile +++ b/Rakefile @@ -7,6 +7,7 @@ rule '.go' => '.go.erb' do |task| end generated_code_files = [ + "pgtype/array_getter_setter.go", "pgtype/int.go", "pgtype/int_test.go", "pgtype/integration_benchmark_test.go", diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 5a02c435..94d24fc9 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -16,6 +16,9 @@ type ArrayGetter interface { // Index returns the element at i. Index(i int) interface{} + + // IndexType returns a non-nil scan target of the type Index will return. This is used by ArrayCodec.PlanEncode. + IndexType() interface{} } // ArraySetter is a type can be set from a PostgreSQL array. @@ -27,6 +30,10 @@ type ArraySetter interface { // ScanIndex returns a value usable as a scan target for i. SetDimensions must be called before ScanIndex. ScanIndex(i int) interface{} + + // ScanIndexType returns a non-nil scan target of the type ScanIndex will return. This is used by + // ArrayCodec.PlanScan. + ScanIndexType() interface{} } // ArrayCodec is a codec for any array type. @@ -43,6 +50,18 @@ func (c *ArrayCodec) PreferredFormat() int16 { } func (c *ArrayCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + arrayValuer, ok := value.(ArrayGetter) + if !ok { + return nil + } + + elementType := arrayValuer.IndexType() + + elementEncodePlan := ci.PlanEncode(c.ElementDataType.OID, format, elementType) + if elementEncodePlan == nil { + return nil + } + switch format { case BinaryFormatCode: return &encodePlanArrayCodecBinary{ac: c, ci: ci, oid: oid} @@ -60,10 +79,7 @@ type encodePlanArrayCodecText struct { } func (p *encodePlanArrayCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - array, err := makeArrayGetter(value) - if err != nil { - return nil, err - } + array := value.(ArrayGetter) dimensions := array.Dimensions() if dimensions == nil { @@ -142,10 +158,7 @@ type encodePlanArrayCodecBinary struct { } func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - array, err := makeArrayGetter(value) - if err != nil { - return nil, err - } + array := value.(ArrayGetter) dimensions := array.Dimensions() if dimensions == nil { @@ -198,8 +211,15 @@ func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newB } func (c *ArrayCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { - _, err := makeArraySetter(target) - if err != nil { + arrayScanner, ok := target.(ArraySetter) + if !ok { + return nil + } + + elementType := arrayScanner.ScanIndexType() + + elementScanPlan := ci.PlanScan(c.ElementDataType.OID, format, elementType) + if _, ok := elementScanPlan.(*scanPlanFail); ok { return nil } @@ -300,10 +320,11 @@ func (c *ArrayCodec) decodeText(ci *ConnInfo, arrayOID uint32, src []byte, array } type scanPlanArrayCodec struct { - arrayCodec *ArrayCodec - ci *ConnInfo - oid uint32 - formatCode int16 + arrayCodec *ArrayCodec + ci *ConnInfo + oid uint32 + formatCode int16 + elementScanPlan ScanPlan } func (spac *scanPlanArrayCodec) Scan(src []byte, dst interface{}) error { @@ -312,11 +333,7 @@ func (spac *scanPlanArrayCodec) Scan(src []byte, dst interface{}) error { oid := spac.oid formatCode := spac.formatCode - array, err := makeArraySetter(dst) - if err != nil { - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(src, dst) - } + array := dst.(ArraySetter) if src == nil { return array.SetDimensions(nil) @@ -358,3 +375,26 @@ func (c *ArrayCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []b err := ci.PlanScan(oid, format, &slice).Scan(src, &slice) return slice, err } + +func isRagged(slice reflect.Value) bool { + if slice.Type().Elem().Kind() != reflect.Slice { + return false + } + + sliceLen := slice.Len() + innerLen := 0 + for i := 0; i < sliceLen; i++ { + if i == 0 { + innerLen = slice.Index(i).Len() + } else { + if slice.Index(i).Len() != innerLen { + return true + } + } + if isRagged(slice.Index(i)) { + return true + } + } + + return false +} diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go index 0c31dcee..b4b9b6a7 100644 --- a/pgtype/array_codec_test.go +++ b/pgtype/array_codec_test.go @@ -108,3 +108,60 @@ func TestArrayCodecDecodeValue(t *testing.T) { }) } } + +func TestArrayCodecScanMultipleDimensions(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + rows, err := conn.Query(context.Background(), `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`) + require.NoError(t, err) + + for rows.Next() { + var ss [][]int32 + err := rows.Scan(&ss) + require.NoError(t, err) + require.Equal(t, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, ss) + } + + require.NoError(t, rows.Err()) +} + +func TestArrayCodecScanWrongMultipleDimensions(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + rows, err := conn.Query(context.Background(), `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`) + require.NoError(t, err) + + for rows.Next() { + var ss [][][]int32 + err := rows.Scan(&ss) + require.Error(t, err, "can't scan into dest[0]: PostgreSQL array has 2 dimensions but slice has 3 dimensions") + } +} + +func TestArrayCodecEncodeMultipleDimensions(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + rows, err := conn.Query(context.Background(), `select $1::int4[]`, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}) + require.NoError(t, err) + + for rows.Next() { + var ss [][]int32 + err := rows.Scan(&ss) + require.NoError(t, err) + require.Equal(t, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, ss) + } + + require.NoError(t, rows.Err()) +} + +func TestArrayCodecEncodeMultipleDimensionsRagged(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + rows, err := conn.Query(context.Background(), `select $1::int4[]`, [][]int32{{1, 2, 3, 4}, {5}, {9, 10, 11, 12}}) + require.Error(t, err, "cannot convert [][]int32 to ArrayGetter because it is a ragged multi-dimensional") + defer rows.Close() +} diff --git a/pgtype/array_getter_setter.go b/pgtype/array_getter_setter.go index 72a6f0e7..2e20f9ec 100644 --- a/pgtype/array_getter_setter.go +++ b/pgtype/array_getter_setter.go @@ -1,10 +1,6 @@ +// Do not edit. Generated from pgtype/array_getter_setter.go.erb package pgtype -import ( - "fmt" - "reflect" -) - type int16Array []int16 func (a int16Array) Dimensions() []ArrayDimension { @@ -19,6 +15,11 @@ func (a int16Array) Index(i int) interface{} { return a[i] } +func (a int16Array) IndexType() interface{} { + var el int16 + return el +} + func (a *int16Array) SetDimensions(dimensions []ArrayDimension) error { if dimensions == nil { a = nil @@ -34,6 +35,10 @@ func (a int16Array) ScanIndex(i int) interface{} { return &a[i] } +func (a int16Array) ScanIndexType() interface{} { + return new(int16) +} + type uint16Array []uint16 func (a uint16Array) Dimensions() []ArrayDimension { @@ -48,6 +53,11 @@ func (a uint16Array) Index(i int) interface{} { return a[i] } +func (a uint16Array) IndexType() interface{} { + var el uint16 + return el +} + func (a *uint16Array) SetDimensions(dimensions []ArrayDimension) error { if dimensions == nil { a = nil @@ -63,81 +73,6 @@ func (a uint16Array) ScanIndex(i int) interface{} { return &a[i] } -type anySliceArray struct { - slice reflect.Value -} - -func (a anySliceArray) Dimensions() []ArrayDimension { - if a.slice.IsNil() { - return nil - } - - return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}} -} - -func (a anySliceArray) Index(i int) interface{} { - return a.slice.Index(i).Interface() -} - -func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error { - sliceType := a.slice.Type() - - if dimensions == nil { - a.slice.Set(reflect.Zero(sliceType)) - return nil - } - - elementCount := cardinality(dimensions) - slice := reflect.MakeSlice(sliceType, elementCount, elementCount) - a.slice.Set(slice) - return nil -} - -func (a anySliceArray) ScanIndex(i int) interface{} { - return a.slice.Index(i).Addr().Interface() -} - -func makeArrayGetter(a interface{}) (ArrayGetter, error) { - switch a := a.(type) { - case ArrayGetter: - return a, nil - - case []int16: - return (*int16Array)(&a), nil - - case []uint16: - return (*uint16Array)(&a), nil - - } - - reflectValue := reflect.ValueOf(a) - if reflectValue.Kind() == reflect.Slice { - return &anySliceArray{slice: reflectValue}, nil - } - - return nil, fmt.Errorf("cannot convert %T to ArrayGetter", a) -} - -func makeArraySetter(a interface{}) (ArraySetter, error) { - switch a := a.(type) { - case ArraySetter: - return a, nil - - case *[]int16: - return (*int16Array)(a), nil - - case *[]uint16: - return (*uint16Array)(a), nil - - } - - value := reflect.ValueOf(a) - if value.Kind() == reflect.Ptr { - elemValue := value.Elem() - if elemValue.Kind() == reflect.Slice { - return &anySliceArray{slice: elemValue}, nil - } - } - - return nil, fmt.Errorf("cannot convert %T to ArraySetter", a) +func (a uint16Array) ScanIndexType() interface{} { + return new(uint16) } diff --git a/pgtype/array_getter_setter.go.erb b/pgtype/array_getter_setter.go.erb index 01b7d4fa..a9d60d35 100644 --- a/pgtype/array_getter_setter.go.erb +++ b/pgtype/array_getter_setter.go.erb @@ -27,6 +27,11 @@ import ( return a[i] } + func (a <%= array_type %>) IndexType() interface{} { + var el <%= element_type %> + return el + } + func (a *<%= array_type %>) SetDimensions(dimensions []ArrayDimension) error { if dimensions == nil { a = nil @@ -41,77 +46,8 @@ import ( func (a <%= array_type %>) ScanIndex(i int) interface{} { return &a[i] } -<% end %> -type anySliceArray struct { - slice reflect.Value -} - -func (a anySliceArray) Dimensions() []ArrayDimension { - if a.slice.IsNil() { - return nil - } - - return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}} -} - -func (a anySliceArray) Index(i int) interface{} { - return a.slice.Index(i).Interface() -} - -func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error { - sliceType := a.slice.Type() - - if dimensions == nil { - a.slice.Set(reflect.Zero(sliceType)) - return nil - } - - elementCount := cardinality(dimensions) - slice := reflect.MakeSlice(sliceType, elementCount, elementCount) - a.slice.Set(slice) - return nil -} - -func (a anySliceArray) ScanIndex(i int) interface{} { - return a.slice.Index(i).Addr().Interface() -} - -func makeArrayGetter(a interface{}) (ArrayGetter, error) { - switch a := a.(type) { - case ArrayGetter: - return a, nil - <% types.each do |array_type, element_type| %> - case []<%= element_type %>: - return (*<%= array_type %>)(&a), nil - <% end %> - } - - reflectValue := reflect.ValueOf(a) - if reflectValue.Kind() == reflect.Slice { - return &anySliceArray{slice: reflectValue}, nil + func (a <%= array_type %>) ScanIndexType() interface{} { + return new(<%= element_type %>) } - - return nil, fmt.Errorf("cannot convert %T to ArrayGetter", a) -} - -func makeArraySetter(a interface{}) (ArraySetter, error) { - switch a := a.(type) { - case ArraySetter: - return a, nil - <% types.each do |array_type, element_type| %> - case *[]<%= element_type %>: - return (*<%= array_type %>)(a), nil - <% end %> - } - - value := reflect.ValueOf(a) - if value.Kind() == reflect.Ptr { - elemValue := value.Elem() - if elemValue.Kind() == reflect.Slice { - return &anySliceArray{slice: elemValue}, nil - } - } - - return nil, fmt.Errorf("cannot convert %T to ArraySetter", a) -} +<% end %> diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 1799de55..466ef45a 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -656,3 +656,168 @@ func (w *ptrStructWrapper) ScanIndex(i int) interface{} { return w.exportedFields[i].Addr().Interface() } + +type anySliceArray struct { + slice reflect.Value +} + +func (a anySliceArray) Dimensions() []ArrayDimension { + if a.slice.IsNil() { + return nil + } + + return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}} +} + +func (a anySliceArray) Index(i int) interface{} { + return a.slice.Index(i).Interface() +} + +func (a anySliceArray) IndexType() interface{} { + return reflect.New(a.slice.Type().Elem()).Elem().Interface() +} + +func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error { + sliceType := a.slice.Type() + + if dimensions == nil { + a.slice.Set(reflect.Zero(sliceType)) + return nil + } + + elementCount := cardinality(dimensions) + slice := reflect.MakeSlice(sliceType, elementCount, elementCount) + a.slice.Set(slice) + return nil +} + +func (a *anySliceArray) ScanIndex(i int) interface{} { + return a.slice.Index(i).Addr().Interface() +} + +func (a *anySliceArray) ScanIndexType() interface{} { + return reflect.New(a.slice.Type().Elem()).Interface() +} + +type anyMultiDimSliceArray struct { + slice reflect.Value + dims []ArrayDimension +} + +func (a *anyMultiDimSliceArray) Dimensions() []ArrayDimension { + if a.slice.IsNil() { + return nil + } + + s := a.slice + for { + a.dims = append(a.dims, ArrayDimension{Length: int32(s.Len()), LowerBound: 1}) + if s.Len() > 0 { + s = s.Index(0) + } else { + break + } + if s.Type().Kind() == reflect.Slice { + } else { + break + } + } + + return a.dims +} + +func (a *anyMultiDimSliceArray) Index(i int) interface{} { + if len(a.dims) == 1 { + return a.slice.Index(i).Interface() + } + + indexes := make([]int, len(a.dims)) + for j := len(a.dims) - 1; j >= 0; j-- { + dimLen := int(a.dims[j].Length) + indexes[j] = i % dimLen + i = i / dimLen + } + + v := a.slice + for _, si := range indexes { + v = v.Index(si) + } + + return v.Interface() +} + +func (a *anyMultiDimSliceArray) IndexType() interface{} { + lowestSliceType := a.slice.Type() + for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() { + } + return reflect.New(lowestSliceType.Elem()).Elem().Interface() +} + +func (a *anyMultiDimSliceArray) SetDimensions(dimensions []ArrayDimension) error { + sliceType := a.slice.Type() + + if dimensions == nil { + a.slice.Set(reflect.Zero(sliceType)) + return nil + } + + switch len(dimensions) { + case 0: + return fmt.Errorf("impossible: non-nil dimensions but zero elements") + case 1: + elementCount := cardinality(dimensions) + slice := reflect.MakeSlice(sliceType, elementCount, elementCount) + a.slice.Set(slice) + return nil + default: + sliceDimensionCount := 1 + lowestSliceType := sliceType + for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() { + sliceDimensionCount++ + } + + if sliceDimensionCount != len(dimensions) { + return fmt.Errorf("PostgreSQL array has %d dimensions but slice has %d dimensions", len(dimensions), sliceDimensionCount) + } + + elementCount := cardinality(dimensions) + flatSlice := reflect.MakeSlice(lowestSliceType, elementCount, elementCount) + + multiDimSlice := a.makeMultidimensionalSlice(sliceType, dimensions, flatSlice, 0) + a.slice.Set(multiDimSlice) + + // Now that a.slice is a multi-dimensional slice with the underlying data pointed at flatSlice change a.slice to + // flatSlice so ScanIndex only has to handle simple one dimensional slices. + a.slice = flatSlice + + return nil + } + +} + +func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type, dimensions []ArrayDimension, flatSlice reflect.Value, flatSliceIdx int) reflect.Value { + if len(dimensions) == 1 { + endIdx := flatSliceIdx + int(dimensions[0].Length) + return flatSlice.Slice3(flatSliceIdx, endIdx, endIdx) + } + + sliceLen := int(dimensions[0].Length) + slice := reflect.MakeSlice(sliceType, sliceLen, sliceLen) + for i := 0; i < sliceLen; i++ { + subSlice := a.makeMultidimensionalSlice(sliceType.Elem(), dimensions[1:], flatSlice, flatSliceIdx+(i*int(dimensions[1].Length))) + slice.Index(i).Set(subSlice) + } + + return slice +} + +func (a *anyMultiDimSliceArray) ScanIndex(i int) interface{} { + return a.slice.Index(i).Addr().Interface() +} + +func (a *anyMultiDimSliceArray) ScanIndexType() interface{} { + lowestSliceType := a.slice.Type() + for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() { + } + return reflect.New(lowestSliceType.Elem()).Interface() +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 8db5ae3f..54792963 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -204,6 +204,8 @@ func NewConnInfo() *ConnInfo { TryWrapBuiltinTypeEncodePlan, TryWrapFindUnderlyingTypeEncodePlan, TryWrapStructEncodePlan, + TryWrapSliceEncodePlan, + TryWrapMultiDimSliceEncodePlan, }, TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{ @@ -211,6 +213,8 @@ func NewConnInfo() *ConnInfo { TryWrapBuiltinTypeScanPlan, TryFindUnderlyingTypeScanPlan, TryWrapStructScanPlan, + TryWrapPtrSliceScanPlan, + TryWrapPtrMultiDimSliceScanPlan, }, } @@ -930,6 +934,62 @@ func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target interface{}) error return plan.next.Scan(src, &w) } +// TryWrapPtrSliceScanPlan tries to wrap a pointer to a single dimension slice. +func TryWrapPtrSliceScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextValue interface{}, ok bool) { + targetValue := reflect.ValueOf(target) + if targetValue.Kind() != reflect.Ptr { + return nil, nil, false + } + + targetElemValue := targetValue.Elem() + + if targetElemValue.Kind() == reflect.Slice { + return &wrapPtrSliceScanPlan{}, &anySliceArray{slice: targetElemValue}, true + } + return nil, nil, false +} + +type wrapPtrSliceScanPlan struct { + next ScanPlan +} + +func (plan *wrapPtrSliceScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapPtrSliceScanPlan) Scan(src []byte, target interface{}) error { + return plan.next.Scan(src, &anySliceArray{slice: reflect.ValueOf(target).Elem()}) +} + +// TryWrapPtrMultiDimSliceScanPlan tries to wrap a pointer to a multi-dimension slice. +func TryWrapPtrMultiDimSliceScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextValue interface{}, ok bool) { + targetValue := reflect.ValueOf(target) + if targetValue.Kind() != reflect.Ptr { + return nil, nil, false + } + + targetElemValue := targetValue.Elem() + + if targetElemValue.Kind() == reflect.Slice { + elemElemKind := targetElemValue.Type().Elem().Kind() + if elemElemKind == reflect.Slice { + if !isRagged(targetElemValue) { + return &wrapPtrMultiDimSliceScanPlan{}, &anyMultiDimSliceArray{slice: targetValue.Elem()}, true + } + } + } + + return nil, nil, false +} + +type wrapPtrMultiDimSliceScanPlan struct { + next ScanPlan +} + +func (plan *wrapPtrMultiDimSliceScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapPtrMultiDimSliceScanPlan) Scan(src []byte, target interface{}) error { + return plan.next.Scan(src, &anyMultiDimSliceArray{slice: reflect.ValueOf(target).Elem()}) +} + // PlanScan prepares a plan to scan a value into target. func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPlan { if _, ok := target.(*UndecodedBytes); ok { @@ -1495,6 +1555,63 @@ func getExportedFieldValues(structValue reflect.Value) []reflect.Value { return exportedFields } +func TryWrapSliceEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { + if reflect.TypeOf(value).Kind() == reflect.Slice { + w := anySliceArray{ + slice: reflect.ValueOf(value), + } + return &wrapSliceEncodePlan{}, w, true + } + + return nil, nil, false +} + +type wrapSliceEncodePlan struct { + next EncodePlan +} + +func (plan *wrapSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapSliceEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + w := anySliceArray{ + slice: reflect.ValueOf(value), + } + + return plan.next.Encode(w, buf) +} + +func TryWrapMultiDimSliceEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { + sliceValue := reflect.ValueOf(value) + if sliceValue.Kind() == reflect.Slice { + valueElemType := sliceValue.Type().Elem() + + if valueElemType.Kind() == reflect.Slice { + if !isRagged(sliceValue) { + w := anyMultiDimSliceArray{ + slice: reflect.ValueOf(value), + } + return &wrapMultiDimSliceEncodePlan{}, &w, true + } + } + } + + return nil, nil, false +} + +type wrapMultiDimSliceEncodePlan struct { + next EncodePlan +} + +func (plan *wrapMultiDimSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapMultiDimSliceEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + w := anyMultiDimSliceArray{ + slice: reflect.ValueOf(value), + } + + return plan.next.Encode(&w, buf) +} + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written.