diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index 4e78ce71..86656524 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -23,6 +23,25 @@ func (dst *Int4Array) Set(src interface{}) error { switch value := src.(type) { + case []int: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int32: if value == nil { *dst = Int4Array{Status: Null} diff --git a/pgtype/int4_array_test.go b/pgtype/int4_array_test.go index 602a3657..f0418600 100644 --- a/pgtype/int4_array_test.go +++ b/pgtype/int4_array_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "math" "reflect" "testing" @@ -54,8 +55,9 @@ func TestInt4ArrayTranscode(t *testing.T) { func TestInt4ArraySet(t *testing.T) { successfulTests := []struct { - source interface{} - result pgtype.Int4Array + source interface{} + result pgtype.Int4Array + expectedError bool }{ { source: []int32{1}, @@ -64,6 +66,17 @@ func TestInt4ArraySet(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, + { + source: []int{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int{1, math.MaxInt32 + 1, 2}, + expectedError: true, + }, { source: []uint32{1}, result: pgtype.Int4Array{ @@ -81,9 +94,17 @@ func TestInt4ArraySet(t *testing.T) { var r pgtype.Int4Array err := r.Set(tt.source) if err != nil { + if tt.expectedError { + continue + } t.Errorf("%d: %v", i, err) } + if tt.expectedError { + t.Errorf("%d: an error was expected, %v", i, tt) + continue + } + if !reflect.DeepEqual(r, tt.result) { t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) }