diff --git a/Rakefile b/Rakefile index de174fae..d957573e 100644 --- a/Rakefile +++ b/Rakefile @@ -7,7 +7,6 @@ rule '.go' => '.go.erb' do |task| end generated_code_files = [ - "pgtype/array_getter_setter.go", "pgtype/int.go", "pgtype/int_test.go", "pgtype/integration_benchmark_test.go", diff --git a/pgtype/array.go b/pgtype/array.go index 8de2b4dd..d34a94e5 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -394,3 +394,88 @@ func findDimensionsFromValue(value reflect.Value, dimensions []ArrayDimension, e } return dimensions, elementsLength, true } + +// Array represents a PostgreSQL array for T. It implements the ArrayGetter and ArraySetter interfaces. It preserves +// PostgreSQL dimensions and custom lower bounds. Use FlatArray if these are not needed. +type Array[T any] struct { + Elements []T + Dims []ArrayDimension + Valid bool +} + +func (a Array[T]) Dimensions() []ArrayDimension { + return a.Dims +} + +func (a Array[T]) Index(i int) any { + return a.Elements[i] +} + +func (a Array[T]) IndexType() any { + var el T + return el +} + +func (a *Array[T]) SetDimensions(dimensions []ArrayDimension) error { + if dimensions == nil { + *a = Array[T]{} + return nil + } + + elementCount := cardinality(dimensions) + *a = Array[T]{ + Elements: make([]T, elementCount), + Dims: dimensions, + Valid: true, + } + + return nil +} + +func (a Array[T]) ScanIndex(i int) any { + return &a.Elements[i] +} + +func (a Array[T]) ScanIndexType() any { + return new(T) +} + +// FlatArray implements the ArrayGetter and ArraySetter interfaces for any slice of T. It ignores PostgreSQL dimensions +// and custom lower bounds. Use Array to preserve these. +type FlatArray[T any] []T + +func (a FlatArray[T]) Dimensions() []ArrayDimension { + if a == nil { + return nil + } + + return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} +} + +func (a FlatArray[T]) Index(i int) any { + return a[i] +} + +func (a FlatArray[T]) IndexType() any { + var el T + return el +} + +func (a *FlatArray[T]) SetDimensions(dimensions []ArrayDimension) error { + if dimensions == nil { + a = nil + return nil + } + + elementCount := cardinality(dimensions) + *a = make(FlatArray[T], elementCount) + return nil +} + +func (a FlatArray[T]) ScanIndex(i int) any { + return &a[i] +} + +func (a FlatArray[T]) ScanIndexType() any { + return new(T) +} diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 379a9096..8aab13bb 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -23,9 +23,9 @@ type ArrayGetter interface { // ArraySetter is a type can be set from a PostgreSQL array. type ArraySetter interface { - // SetDimensions prepares the value such that ScanIndex can be called for each element. dimensions may be nil to - // indicate a NULL array. If unable to exactly preserve dimensions SetDimensions may return an error or silently - // flatten the array dimensions. + // SetDimensions prepares the value such that ScanIndex can be called for each element. This will remove any existing + // elements. dimensions may be nil to indicate a NULL array. If unable to exactly preserve dimensions SetDimensions + // may return an error or silently flatten the array dimensions. SetDimensions(dimensions []ArrayDimension) error // ScanIndex returns a value usable as a scan target for i. SetDimensions must be called before ScanIndex. diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go index e4c00d1e..65289d04 100644 --- a/pgtype/array_codec_test.go +++ b/pgtype/array_codec_test.go @@ -5,6 +5,7 @@ import ( "testing" pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -47,6 +48,53 @@ func TestArrayCodec(t *testing.T) { }) } +func TestArrayCodecFlatArray(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + for i, tt := range []struct { + expected any + }{ + {pgtype.FlatArray[int32](nil)}, + {pgtype.FlatArray[int32]{}}, + {pgtype.FlatArray[int32]{1, 2, 3}}, + } { + var actual pgtype.FlatArray[int32] + err := conn.QueryRow( + ctx, + "select $1::int[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + }) +} + +func TestArrayCodecArray(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + for i, tt := range []struct { + expected any + }{ + {pgtype.Array[int32]{ + Elements: []int32{1, 2, 3, 4}, + Dims: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 2}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }}, + } { + var actual pgtype.Array[int32] + err := conn.QueryRow( + ctx, + "select $1::int[]", + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + }) +} + func TestArrayCodecAnySlice(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { type _int16Slice []int16 diff --git a/pgtype/array_getter_setter.go b/pgtype/array_getter_setter.go deleted file mode 100644 index b0c6b505..00000000 --- a/pgtype/array_getter_setter.go +++ /dev/null @@ -1,78 +0,0 @@ -// Do not edit. Generated from pgtype/array_getter_setter.go.erb -package pgtype - -type int16Array []int16 - -func (a int16Array) Dimensions() []ArrayDimension { - if a == nil { - return nil - } - - return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} -} - -func (a int16Array) Index(i int) any { - return a[i] -} - -func (a int16Array) IndexType() any { - var el int16 - return el -} - -func (a *int16Array) SetDimensions(dimensions []ArrayDimension) error { - if dimensions == nil { - a = nil - return nil - } - - elementCount := cardinality(dimensions) - *a = make(int16Array, elementCount) - return nil -} - -func (a int16Array) ScanIndex(i int) any { - return &a[i] -} - -func (a int16Array) ScanIndexType() any { - return new(int16) -} - -type uint16Array []uint16 - -func (a uint16Array) Dimensions() []ArrayDimension { - if a == nil { - return nil - } - - return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} -} - -func (a uint16Array) Index(i int) any { - return a[i] -} - -func (a uint16Array) IndexType() any { - var el uint16 - return el -} - -func (a *uint16Array) SetDimensions(dimensions []ArrayDimension) error { - if dimensions == nil { - a = nil - return nil - } - - elementCount := cardinality(dimensions) - *a = make(uint16Array, elementCount) - return nil -} - -func (a uint16Array) ScanIndex(i int) any { - return &a[i] -} - -func (a uint16Array) ScanIndexType() any { - return new(uint16) -} diff --git a/pgtype/array_getter_setter.go.erb b/pgtype/array_getter_setter.go.erb deleted file mode 100644 index 1c8cdff4..00000000 --- a/pgtype/array_getter_setter.go.erb +++ /dev/null @@ -1,53 +0,0 @@ -package pgtype - -import ( - "fmt" - "reflect" -) - -<% - types = [ - ["int16Array", "int16"], - ["uint16Array", "uint16"], - ] -%> - -<% types.each do |array_type, element_type| %> - type <%= array_type %> []<%= element_type %> - - func (a <%= array_type %>) Dimensions() []ArrayDimension { - if a == nil { - return nil - } - - return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} - } - - func (a <%= array_type %>) Index(i int) any { - return a[i] - } - - func (a <%= array_type %>) IndexType() any { - var el <%= element_type %> - return el - } - - func (a *<%= array_type %>) SetDimensions(dimensions []ArrayDimension) error { - if dimensions == nil { - a = nil - return nil - } - - elementCount := cardinality(dimensions) - *a = make(<%= array_type %>, elementCount) - return nil - } - - func (a <%= array_type %>) ScanIndex(i int) any { - return &a[i] - } - - func (a <%= array_type %>) ScanIndexType() any { - return new(<%= element_type %>) - } -<% end %> diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index b385b80a..da9cf0bb 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -637,11 +637,11 @@ func (w *ptrStructWrapper) ScanIndex(i int) any { return w.exportedFields[i].Addr().Interface() } -type anySliceArray struct { +type anySliceArrayReflect struct { slice reflect.Value } -func (a anySliceArray) Dimensions() []ArrayDimension { +func (a anySliceArrayReflect) Dimensions() []ArrayDimension { if a.slice.IsNil() { return nil } @@ -649,15 +649,15 @@ func (a anySliceArray) Dimensions() []ArrayDimension { return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}} } -func (a anySliceArray) Index(i int) any { +func (a anySliceArrayReflect) Index(i int) any { return a.slice.Index(i).Interface() } -func (a anySliceArray) IndexType() any { +func (a anySliceArrayReflect) IndexType() any { return reflect.New(a.slice.Type().Elem()).Elem().Interface() } -func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error { +func (a *anySliceArrayReflect) SetDimensions(dimensions []ArrayDimension) error { sliceType := a.slice.Type() if dimensions == nil { @@ -671,11 +671,11 @@ func (a *anySliceArray) SetDimensions(dimensions []ArrayDimension) error { return nil } -func (a *anySliceArray) ScanIndex(i int) any { +func (a *anySliceArrayReflect) ScanIndex(i int) any { return a.slice.Index(i).Addr().Interface() } -func (a *anySliceArray) ScanIndexType() any { +func (a *anySliceArrayReflect) ScanIndexType() any { return reflect.New(a.slice.Type().Elem()).Interface() } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index e35299e5..db916220 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -993,6 +993,24 @@ func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target any) error { // TryWrapPtrSliceScanPlan tries to wrap a pointer to a single dimension slice. func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { + // Avoid using reflect path for common types. + switch target := target.(type) { + case *[]int16: + return &wrapPtrSliceScanPlan[int16]{}, (*FlatArray[int16])(target), true + case *[]int32: + return &wrapPtrSliceScanPlan[int32]{}, (*FlatArray[int32])(target), true + case *[]int64: + return &wrapPtrSliceScanPlan[int64]{}, (*FlatArray[int64])(target), true + case *[]float32: + return &wrapPtrSliceScanPlan[float32]{}, (*FlatArray[float32])(target), true + case *[]float64: + return &wrapPtrSliceScanPlan[float64]{}, (*FlatArray[float64])(target), true + case *[]string: + return &wrapPtrSliceScanPlan[string]{}, (*FlatArray[string])(target), true + case *[]time.Time: + return &wrapPtrSliceScanPlan[time.Time]{}, (*FlatArray[time.Time])(target), true + } + targetValue := reflect.ValueOf(target) if targetValue.Kind() != reflect.Ptr { return nil, nil, false @@ -1001,19 +1019,29 @@ func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextVa targetElemValue := targetValue.Elem() if targetElemValue.Kind() == reflect.Slice { - return &wrapPtrSliceScanPlan{}, &anySliceArray{slice: targetElemValue}, true + return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: targetElemValue}, true } return nil, nil, false } -type wrapPtrSliceScanPlan struct { +type wrapPtrSliceScanPlan[T any] struct { next ScanPlan } -func (plan *wrapPtrSliceScanPlan) SetNext(next ScanPlan) { plan.next = next } +func (plan *wrapPtrSliceScanPlan[T]) SetNext(next ScanPlan) { plan.next = next } -func (plan *wrapPtrSliceScanPlan) Scan(src []byte, target any) error { - return plan.next.Scan(src, &anySliceArray{slice: reflect.ValueOf(target).Elem()}) +func (plan *wrapPtrSliceScanPlan[T]) Scan(src []byte, target any) error { + return plan.next.Scan(src, (*FlatArray[T])(target.(*[]T))) +} + +type wrapPtrSliceReflectScanPlan struct { + next ScanPlan +} + +func (plan *wrapPtrSliceReflectScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapPtrSliceReflectScanPlan) Scan(src []byte, target any) error { + return plan.next.Scan(src, &anySliceArrayReflect{slice: reflect.ValueOf(target).Elem()}) } // TryWrapPtrMultiDimSliceScanPlan tries to wrap a pointer to a multi-dimension slice. @@ -1660,24 +1688,56 @@ func getExportedFieldValues(structValue reflect.Value) []reflect.Value { } func TryWrapSliceEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + // Avoid using reflect path for common types. + switch value := value.(type) { + case []int16: + return &wrapSliceEncodePlan[int16]{}, (FlatArray[int16])(value), true + case []int32: + return &wrapSliceEncodePlan[int32]{}, (FlatArray[int32])(value), true + case []int64: + return &wrapSliceEncodePlan[int64]{}, (FlatArray[int64])(value), true + case []float32: + return &wrapSliceEncodePlan[float32]{}, (FlatArray[float32])(value), true + case []float64: + return &wrapSliceEncodePlan[float64]{}, (FlatArray[float64])(value), true + case []string: + return &wrapSliceEncodePlan[string]{}, (FlatArray[string])(value), true + case []time.Time: + return &wrapSliceEncodePlan[time.Time]{}, (FlatArray[time.Time])(value), true + } + if reflect.TypeOf(value).Kind() == reflect.Slice { - w := anySliceArray{ + w := anySliceArrayReflect{ slice: reflect.ValueOf(value), } - return &wrapSliceEncodePlan{}, w, true + return &wrapSliceEncodeReflectPlan{}, w, true } return nil, nil, false } -type wrapSliceEncodePlan struct { +type wrapSliceEncodePlan[T any] struct { next EncodePlan } -func (plan *wrapSliceEncodePlan) SetNext(next EncodePlan) { plan.next = next } +func (plan *wrapSliceEncodePlan[T]) SetNext(next EncodePlan) { plan.next = next } -func (plan *wrapSliceEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) { - w := anySliceArray{ +func (plan *wrapSliceEncodePlan[T]) Encode(value any, buf []byte) (newBuf []byte, err error) { + w := anySliceArrayReflect{ + slice: reflect.ValueOf(value), + } + + return plan.next.Encode(w, buf) +} + +type wrapSliceEncodeReflectPlan struct { + next EncodePlan +} + +func (plan *wrapSliceEncodeReflectPlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapSliceEncodeReflectPlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + w := anySliceArrayReflect{ slice: reflect.ValueOf(value), }