diff --git a/pgtype/CHANGELOG.md b/pgtype/CHANGELOG.md new file mode 100644 index 00000000..e34c7979 --- /dev/null +++ b/pgtype/CHANGELOG.md @@ -0,0 +1,121 @@ +# 1.9.1 (November 28, 2021) + +* Fix: binary timestamp is assumed to be in UTC (restored behavior changed in v1.9.0) + +# 1.9.0 (November 20, 2021) + +* Fix binary hstore null decoding +* Add shopspring/decimal.NullDecimal support to integration (Eli Treuherz) +* Inet.Set supports bare IP address (Carl Dunham) +* Add zeronull.Float8 +* Fix NULL being lost when scanning unknown OID into sql.Scanner +* Fix BPChar.AssignTo **rune +* Add support for fmt.Stringer and driver.Valuer in String fields encoding (Jan Dubsky) +* Fix really big timestamp(tz)s binary format parsing (e.g. year 294276) (Jim Tsao) +* Support `map[string]*string` as hstore (Adrian Sieger) +* Fix parsing text array with negative bounds +* Add infinity support for numeric (Jim Tsao) + +# 1.8.1 (July 24, 2021) + +* Cleaned up Go module dependency chain + +# 1.8.0 (July 10, 2021) + +* Maintain host bits for inet types (Cameron Daniel) +* Support pointers of wrapping structs (Ivan Daunis) +* Register JSONBArray at NewConnInfo() (Rueian) +* CompositeTextScanner handles backslash escapes + +# 1.7.0 (March 25, 2021) + +* Fix scanning int into **sql.Scanner implementor +* Add tsrange array type (Vasilii Novikov) +* Fix: escaped strings when they start or end with a newline char (Stephane Martin) +* Accept nil *time.Time in Time.Set +* Fix numeric NaN support +* Use Go 1.13 errors instead of xerrors + +# 1.6.2 (December 3, 2020) + +* Fix panic on assigning empty array to non-slice or array +* Fix text array parsing disambiguates NULL and "NULL" +* Fix Timestamptz.DecodeText with too short text + +# 1.6.1 (October 31, 2020) + +* Fix simple protocol empty array support + +# 1.6.0 (October 24, 2020) + +* Fix AssignTo pointer to pointer to slice and named types. +* Fix zero length array assignment (Simo Haasanen) +* Add float64, float32 convert to int2, int4, int8 (lqu3j) +* Support setting infinite timestamps (Erik Agsjö) +* Polygon improvements (duohedron) +* Fix Inet.Set with nil (Tomas Volf) + +# 1.5.0 (September 26, 2020) + +* Add slice of slice mapping to multi-dimensional arrays (Simo Haasanen) +* Fix JSONBArray +* Fix selecting empty array +* Text formatted values except bytea can be directly scanned to []byte +* Add JSON marshalling for UUID (bakmataliev) +* Improve point type conversions (bakmataliev) + +# 1.4.2 (July 22, 2020) + +* Fix encoding of a large composite data type (Yaz Saito) + +# 1.4.1 (July 14, 2020) + +* Fix ArrayType DecodeBinary empty array breaks future reads + +# 1.4.0 (June 27, 2020) + +* Add JSON support to ext/gofrs-uuid +* Performance improvements in Scan path +* Improved ext/shopspring-numeric binary decoding performance +* Add composite type support (Maxim Ivanov and Jack Christensen) +* Add better generic enum type support +* Add generic array type support +* Clarify and normalize Value semantics +* Fix hstore with empty string values +* Numeric supports NaN values (leighhopcroft) +* Add slice of pointer support to array types (megaturbo) +* Add jsonb array type (tserakhau) +* Allow converting intervals with months and days to duration + +# 1.3.0 (March 30, 2020) + +* Get implemented on T instead of *T +* Set will call Get on src if possible +* Range types Set method supports its own type, string, and nil +* Date.Set parses string +* Fix correct format verb for unknown type error (Robert Welin) +* Truncate nanoseconds in EncodeText for Timestamptz and Timestamp + +# 1.2.0 (February 5, 2020) + +* Add zeronull package for easier NULL <-> zero conversion +* Add JSON marshalling for shopspring-numeric extension +* Add JSON marshalling for Bool, Date, JSON/B, Timestamptz (Jeffrey Stiles) +* Fix null status in UnmarshalJSON for some types (Jeffrey Stiles) + +# 1.1.0 (January 11, 2020) + +* Add PostgreSQL time type support +* Add more automatic conversions of integer arrays of different types (Jean-Philippe Quéméner) + +# 1.0.3 (November 16, 2019) + +* Support initializing Array types from a slice of the value (Alex Gaynor) + +# 1.0.2 (October 22, 2019) + +* Fix scan into null into pointer to pointer implementing Decode* interface. (Jeremy Altavilla) + +# 1.0.1 (September 19, 2019) + +* Fix daterange OID diff --git a/pgtype/LICENSE b/pgtype/LICENSE new file mode 100644 index 00000000..5c486c39 --- /dev/null +++ b/pgtype/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2013-2021 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/pgtype/README.md b/pgtype/README.md new file mode 100644 index 00000000..bc4e72f9 --- /dev/null +++ b/pgtype/README.md @@ -0,0 +1,8 @@ +[![](https://godoc.org/github.com/jackc/pgtype?status.svg)](https://godoc.org/github.com/jackc/pgtype) +![CI](https://github.com/jackc/pgtype/workflows/CI/badge.svg) + +# pgtype + +pgtype implements Go types for over 70 PostgreSQL types. pgtype is the type system underlying the +https://github.com/jackc/pgx PostgreSQL driver. These types support the binary format for enhanced performance with pgx. +They also support the database/sql `Scan` and `Value` interfaces. diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go new file mode 100644 index 00000000..0c1f23b5 --- /dev/null +++ b/pgtype/aclitem.go @@ -0,0 +1,127 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem +// might look like this: +// +// postgres=arwdDxt/postgres +// +// Note, however, that because the user/role name part of an aclitem is +// an identifier, it follows all the usual formatting rules for SQL +// identifiers: if it contains spaces and other special characters, +// it should appear in double-quotes: +// +// postgres=arwdDxt/"role with spaces" +// +type ACLItem struct { + String string + Valid bool +} + +func (dst *ACLItem) Set(src interface{}) error { + if src == nil { + *dst = ACLItem{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case string: + *dst = ACLItem{String: value, Valid: true} + case *string: + if value == nil { + *dst = ACLItem{} + } else { + *dst = ACLItem{String: *value, Valid: true} + } + default: + if originalSrc, ok := underlyingStringType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to ACLItem", value) + } + + return nil +} + +func (dst ACLItem) Get() interface{} { + if !dst.Valid { + return nil + } + return dst.String +} + +func (src *ACLItem) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *string: + *v = src.String + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = ACLItem{} + return nil + } + + *dst = ACLItem{String: string(src), Valid: true} + return nil +} + +func (src ACLItem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return append(buf, src.String...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *ACLItem) Scan(src interface{}) error { + if src == nil { + *dst = ACLItem{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src ACLItem) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + return src.String, nil +} diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go new file mode 100644 index 00000000..fc1128b7 --- /dev/null +++ b/pgtype/aclitem_array.go @@ -0,0 +1,418 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "fmt" + "reflect" +) + +type ACLItemArray struct { + Elements []ACLItem + Dimensions []ArrayDimension + Valid bool +} + +func (dst *ACLItemArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = ACLItemArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []string: + if value == nil { + *dst = ACLItemArray{} + } else if len(value) == 0 { + *dst = ACLItemArray{Valid: true} + } else { + elements := make([]ACLItem, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = ACLItemArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*string: + if value == nil { + *dst = ACLItemArray{} + } else if len(value) == 0 { + *dst = ACLItemArray{Valid: true} + } else { + elements := make([]ACLItem, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = ACLItemArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []ACLItem: + if value == nil { + *dst = ACLItemArray{} + } else if len(value) == 0 { + *dst = ACLItemArray{Valid: true} + } else { + *dst = ACLItemArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = ACLItemArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for ACLItemArray", src) + } + if elementsLength == 0 { + *dst = ACLItemArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to ACLItemArray", src) + } + + *dst = ACLItemArray{ + Elements: make([]ACLItem, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]ACLItem, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to ACLItemArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *ACLItemArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to ACLItemArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in ACLItemArray", err) + } + index++ + + return index, nil +} + +func (dst ACLItemArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *ACLItemArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *ACLItemArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from ACLItemArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from ACLItemArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = ACLItemArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []ACLItem + + if len(uta.Elements) > 0 { + elements = make([]ACLItem, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem ACLItem + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (src ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *ACLItemArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src ACLItemArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/aclitem_array_test.go b/pgtype/aclitem_array_test.go new file mode 100644 index 00000000..0d6adb1d --- /dev/null +++ b/pgtype/aclitem_array_test.go @@ -0,0 +1,329 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestACLItemArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "aclitem[]", []interface{}{ + &pgtype.ACLItemArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.ACLItemArray{}, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + //{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}, + {String: `postgres=arwdDxt/postgres`, Valid: true}, // todo: remove after fixing above case + {String: "=r/postgres", Valid: true}, + {}, + {String: "=r/postgres", Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestACLItemArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ACLItemArray + }{ + { + source: []string{"=r/postgres"}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]string)(nil)), + result: pgtype.ACLItemArray{}, + }, + { + source: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.ACLItemArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestACLItemArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string + + simpleTests := []struct { + src pgtype.ACLItemArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &stringSlice, + expected: []string{"=r/postgres"}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &namedStringSlice, + expected: _stringSlice{"=r/postgres"}, + }, + { + src: pgtype.ACLItemArray{}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + { + src: pgtype.ACLItemArray{Valid: true}, + dst: &stringSlice, + expected: []string{}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &stringSliceDim2, + expected: [][]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &stringSliceDim4, + expected: [][][][]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &stringArrayDim2, + expected: [2][1]string{{"=r/postgres"}, {"postgres=arwdDxt/postgres"}}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}, + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{ + {{{ + "=r/postgres", + "postgres=arwdDxt/postgres", + "=r/postgres"}}}, + {{{ + "postgres=arwdDxt/postgres", + "=r/postgres", + "postgres=arwdDxt/postgres"}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.ACLItemArray + dst interface{} + }{ + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &stringSlice, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &stringArrayDim2, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &stringSlice, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + {String: "=r/postgres", Valid: true}, + {String: "postgres=arwdDxt/postgres", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &stringArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/aclitem_test.go b/pgtype/aclitem_test.go new file mode 100644 index 00000000..4e9bc5b0 --- /dev/null +++ b/pgtype/aclitem_test.go @@ -0,0 +1,97 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestACLItemTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{ + &pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Valid: true}, + //&pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}, + &pgtype.ACLItem{}, + }) +} + +func TestACLItemSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ACLItem + }{ + {source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Valid: true}}, + {source: (*string)(nil), result: pgtype.ACLItem{}}, + } + + for i, tt := range successfulTests { + var d pgtype.ACLItem + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestACLItemAssignTo(t *testing.T) { + var s string + var ps *string + + simpleTests := []struct { + src pgtype.ACLItem + dst interface{} + expected interface{} + }{ + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Valid: true}, dst: &s, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.ACLItem{}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.ACLItem + dst interface{} + expected interface{} + }{ + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Valid: true}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.ACLItem + dst interface{} + }{ + {src: pgtype.ACLItem{}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/array.go b/pgtype/array.go new file mode 100644 index 00000000..174007c1 --- /dev/null +++ b/pgtype/array.go @@ -0,0 +1,381 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "reflect" + "strconv" + "strings" + "unicode" + + "github.com/jackc/pgio" +) + +// Information on the internals of PostgreSQL arrays can be found in +// src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of +// particular interest is the array_send function. + +type ArrayHeader struct { + ContainsNull bool + ElementOID int32 + Dimensions []ArrayDimension +} + +type ArrayDimension struct { + Length int32 + LowerBound int32 +} + +func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { + if len(src) < 12 { + return 0, fmt.Errorf("array header too short: %d", len(src)) + } + + rp := 0 + + numDims := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 + rp += 4 + + dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + if numDims > 0 { + dst.Dimensions = make([]ArrayDimension, numDims) + } + if len(src) < 12+numDims*8 { + return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) + } + for i := range dst.Dimensions { + dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + dst.Dimensions[i].LowerBound = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + } + + return rp, nil +} + +func (src ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte { + buf = pgio.AppendInt32(buf, int32(len(src.Dimensions))) + + var containsNull int32 + if src.ContainsNull { + containsNull = 1 + } + buf = pgio.AppendInt32(buf, containsNull) + + buf = pgio.AppendInt32(buf, src.ElementOID) + + for i := range src.Dimensions { + buf = pgio.AppendInt32(buf, src.Dimensions[i].Length) + buf = pgio.AppendInt32(buf, src.Dimensions[i].LowerBound) + } + + return buf +} + +type UntypedTextArray struct { + Elements []string + Quoted []bool + Dimensions []ArrayDimension +} + +func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { + dst := &UntypedTextArray{} + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + var explicitDimensions []ArrayDimension + + // Array has explicit dimensions + if r == '[' { + buf.UnreadRune() + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r == '=' { + break + } else if r != '[' { + return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r) + } + + lower, err := arrayParseInteger(buf) + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != ':' { + return nil, fmt.Errorf("invalid array, expected ':' got %v", r) + } + + upper, err := arrayParseInteger(buf) + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != ']' { + return nil, fmt.Errorf("invalid array, expected ']' got %v", r) + } + + explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1}) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + } + + if r != '{' { + return nil, fmt.Errorf("invalid array, expected '{': %v", err) + } + + implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}} + + // Consume all initial opening brackets. This provides number of dimensions. + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r == '{' { + implicitDimensions[len(implicitDimensions)-1].Length = 1 + implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1}) + } else { + buf.UnreadRune() + break + } + } + currentDim := len(implicitDimensions) - 1 + counterDim := currentDim + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + switch r { + case '{': + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + currentDim++ + case ',': + case '}': + currentDim-- + if currentDim < counterDim { + counterDim = currentDim + } + default: + buf.UnreadRune() + value, quoted, err := arrayParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid array value: %v", err) + } + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + dst.Quoted = append(dst.Quoted, quoted) + dst.Elements = append(dst.Elements, value) + } + + if currentDim < 0 { + break + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + if len(dst.Elements) == 0 { + dst.Dimensions = nil + } else if len(explicitDimensions) > 0 { + dst.Dimensions = explicitDimensions + } else { + dst.Dimensions = implicitDimensions + } + + return dst, nil +} + +func skipWhitespace(buf *bytes.Buffer) { + var r rune + var err error + for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() { + } + + if err != io.EOF { + buf.UnreadRune() + } +} + +func arrayParseValue(buf *bytes.Buffer) (string, bool, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", false, err + } + if r == '"' { + return arrayParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", false, err + } + + switch r { + case ',', '}': + buf.UnreadRune() + return s.String(), false, nil + } + + s.WriteRune(r) + } +} + +func arrayParseQuotedValue(buf *bytes.Buffer) (string, bool, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", false, err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", false, err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", false, err + } + buf.UnreadRune() + return s.String(), true, nil + } + s.WriteRune(r) + } +} + +func arrayParseInteger(buf *bytes.Buffer) (int32, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return 0, err + } + + if ('0' <= r && r <= '9') || r == '-' { + s.WriteRune(r) + } else { + buf.UnreadRune() + n, err := strconv.ParseInt(s.String(), 10, 32) + if err != nil { + return 0, err + } + return int32(n), nil + } + } +} + +func EncodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte { + var customDimensions bool + for _, dim := range dimensions { + if dim.LowerBound != 1 { + customDimensions = true + } + } + + if !customDimensions { + return buf + } + + for _, dim := range dimensions { + buf = append(buf, '[') + buf = append(buf, strconv.FormatInt(int64(dim.LowerBound), 10)...) + buf = append(buf, ':') + buf = append(buf, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)...) + buf = append(buf, ']') + } + + return append(buf, '=') +} + +var quoteArrayReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteArrayElement(src string) string { + return `"` + quoteArrayReplacer.Replace(src) + `"` +} + +func isSpace(ch byte) bool { + // see https://github.com/postgres/postgres/blob/REL_12_STABLE/src/backend/parser/scansup.c#L224 + return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\f' +} + +func QuoteArrayElementIfNeeded(src string) string { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) { + return quoteArrayElement(src) + } + return src +} + +func findDimensionsFromValue(value reflect.Value, dimensions []ArrayDimension, elementsLength int) ([]ArrayDimension, int, bool) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + length := value.Len() + if 0 == elementsLength { + elementsLength = length + } else { + elementsLength *= length + } + dimensions = append(dimensions, ArrayDimension{Length: int32(length), LowerBound: 1}) + for i := 0; i < length; i++ { + if d, l, ok := findDimensionsFromValue(value.Index(i), dimensions, elementsLength); ok { + return d, l, true + } + } + } + return dimensions, elementsLength, true +} diff --git a/pgtype/array_test.go b/pgtype/array_test.go new file mode 100644 index 00000000..f1fe90f4 --- /dev/null +++ b/pgtype/array_test.go @@ -0,0 +1,135 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/stretchr/testify/require" +) + +func TestParseUntypedTextArray(t *testing.T) { + tests := []struct { + source string + result pgtype.UntypedTextArray + }{ + { + source: "{}", + result: pgtype.UntypedTextArray{ + Elements: nil, + Quoted: nil, + Dimensions: nil, + }, + }, + { + source: "{1}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1"}, + Quoted: []bool{false}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{a,b}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b"}, + Quoted: []bool{false, false}, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + }, + }, + { + source: `{"NULL"}`, + result: pgtype.UntypedTextArray{ + Elements: []string{"NULL"}, + Quoted: []bool{true}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: `{""}`, + result: pgtype.UntypedTextArray{ + Elements: []string{""}, + Quoted: []bool{true}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: `{"He said, \"Hello.\""}`, + result: pgtype.UntypedTextArray{ + Elements: []string{`He said, "Hello."`}, + Quoted: []bool{true}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{{a,b},{c,d},{e,f}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f"}, + Quoted: []bool{false, false, false, false, false, false}, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + }, + }, + { + source: "{{{a,b},{c,d},{e,f}},{{a,b},{c,d},{e,f}}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f", "a", "b", "c", "d", "e", "f"}, + Quoted: []bool{false, false, false, false, false, false, false, false, false, false, false, false}, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 1}, + {Length: 3, LowerBound: 1}, + {Length: 2, LowerBound: 1}, + }, + }, + }, + { + source: "[4:4]={1}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1"}, + Quoted: []bool{false}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 4}}, + }, + }, + { + source: "[4:5][2:3]={{a,b},{c,d}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d"}, + Quoted: []bool{false, false, false, false}, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + }, + }, + { + source: "[-4:-2]={1,2,3}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1", "2", "3"}, + Quoted: []bool{false, false, false}, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: -4}}, + }, + }, + } + + for i, tt := range tests { + r, err := pgtype.ParseUntypedTextArray(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !reflect.DeepEqual(*r, tt.result) { + t.Errorf("%d: expected %+v to be parsed to %+v, but it was %+v", i, tt.source, tt.result, *r) + } + } +} + +// https://github.com/jackc/pgx/issues/881 +func TestArrayAssignToEmptyToNonSlice(t *testing.T) { + var a pgtype.Int4Array + err := a.Set([]int32{}) + require.NoError(t, err) + + var iface interface{} + err = a.AssignTo(&iface) + require.EqualError(t, err, "cannot assign *pgtype.Int4Array to *interface {}") +} diff --git a/pgtype/array_type.go b/pgtype/array_type.go new file mode 100644 index 00000000..c4f162af --- /dev/null +++ b/pgtype/array_type.go @@ -0,0 +1,368 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +// ArrayType represents an array type. While it implements Value, this is only in service of its type conversion duties +// when registered as a data type in a ConnType. It should not be used directly as a Value. ArrayType is a convenience +// type for types that do not have an concrete array type. +type ArrayType struct { + elements []ValueTranscoder + dimensions []ArrayDimension + + typeName string + newElement func() ValueTranscoder + + elementOID uint32 + valid bool +} + +func NewArrayType(typeName string, elementOID uint32, newElement func() ValueTranscoder) *ArrayType { + return &ArrayType{typeName: typeName, elementOID: elementOID, newElement: newElement} +} + +func (at *ArrayType) NewTypeValue() Value { + return &ArrayType{ + elements: at.elements, + dimensions: at.dimensions, + valid: at.valid, + + typeName: at.typeName, + elementOID: at.elementOID, + newElement: at.newElement, + } +} + +func (at *ArrayType) TypeName() string { + return at.typeName +} + +func (dst *ArrayType) setNil() { + dst.elements = nil + dst.dimensions = nil + dst.valid = false +} + +func (dst *ArrayType) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + dst.setNil() + return nil + } + + sliceVal := reflect.ValueOf(src) + if sliceVal.Kind() != reflect.Slice { + return fmt.Errorf("cannot set non-slice") + } + + if sliceVal.IsNil() { + dst.setNil() + return nil + } + + dst.elements = make([]ValueTranscoder, sliceVal.Len()) + for i := range dst.elements { + v := dst.newElement() + err := v.Set(sliceVal.Index(i).Interface()) + if err != nil { + return err + } + + dst.elements[i] = v + } + dst.dimensions = []ArrayDimension{{Length: int32(len(dst.elements)), LowerBound: 1}} + dst.valid = true + + return nil +} + +func (src ArrayType) Get() interface{} { + if !src.valid { + return nil + } + + elementValues := make([]interface{}, len(src.elements)) + for i := range src.elements { + elementValues[i] = src.elements[i].Get() + } + return elementValues +} + +func (src *ArrayType) AssignTo(dst interface{}) error { + ptrSlice := reflect.ValueOf(dst) + if ptrSlice.Kind() != reflect.Ptr { + return fmt.Errorf("cannot assign to non-pointer") + } + + sliceVal := ptrSlice.Elem() + sliceType := sliceVal.Type() + + if sliceType.Kind() != reflect.Slice { + return fmt.Errorf("cannot assign to pointer to non-slice") + } + + if src.valid { + slice := reflect.MakeSlice(sliceType, len(src.elements), len(src.elements)) + elemType := sliceType.Elem() + + for i := range src.elements { + ptrElem := reflect.New(elemType) + err := src.elements[i].AssignTo(ptrElem.Interface()) + if err != nil { + return err + } + + slice.Index(i).Set(ptrElem.Elem()) + } + + sliceVal.Set(slice) + return nil + } else { + sliceVal.Set(reflect.Zero(sliceType)) + return nil + } +} + +func (ArrayType) BinaryFormatSupported() bool { + return true +} + +func (ArrayType) TextFormatSupported() bool { + return true +} + +func (ArrayType) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *ArrayType) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + if src == nil { + dst.setNil() + return nil + } + + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src ArrayType) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} + +func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []ValueTranscoder + + if len(uta.Elements) > 0 { + elements = make([]ValueTranscoder, len(uta.Elements)) + + for i, s := range uta.Elements { + elem := dst.newElement() + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeResult(ci, dst.elementOID, TextFormatCode, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + dst.elements = elements + dst.dimensions = uta.Dimensions + dst.valid = true + + return nil +} + +func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error { + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + var elements []ValueTranscoder + + if len(arrayHeader.Dimensions) == 0 { + dst.elements = elements + dst.dimensions = arrayHeader.Dimensions + dst.valid = true + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements = make([]ValueTranscoder, elementCount) + + for i := range elements { + elem := dst.newElement() + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elem.DecodeResult(ci, dst.elementOID, BinaryFormatCode, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + + dst.elements = elements + dst.dimensions = arrayHeader.Dimensions + dst.valid = true + + return nil +} + +func (src ArrayType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.valid { + return nil, nil + } + + if len(src.dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.dimensions)) + dimElemCounts[len(src.dimensions)-1] = int(src.dimensions[len(src.dimensions)-1].Length) + for i := len(src.dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeParam(ci, src.elementOID, TextFormatCode, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src ArrayType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.dimensions, + ElementOID: int32(src.elementOID), + } + + for i := range src.elements { + if src.elements[i].Get() == nil { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.elements[i].EncodeParam(ci, src.elementOID, BinaryFormatCode, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *ArrayType) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src ArrayType) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/array_type_test.go b/pgtype/array_type_test.go new file mode 100644 index 00000000..626df4dc --- /dev/null +++ b/pgtype/array_type_test.go @@ -0,0 +1,84 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestArrayTypeValue(t *testing.T) { + arrayType := pgtype.NewArrayType("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }) + + err := arrayType.Set(nil) + require.NoError(t, err) + + gotValue := arrayType.Get() + require.Nil(t, gotValue) + + slice := []string{"foo", "bar"} + err = arrayType.AssignTo(&slice) + require.NoError(t, err) + require.Nil(t, slice) + + err = arrayType.Set([]string{}) + require.NoError(t, err) + + gotValue = arrayType.Get() + require.Len(t, gotValue, 0) + + err = arrayType.AssignTo(&slice) + require.NoError(t, err) + require.EqualValues(t, []string{}, slice) + + err = arrayType.Set([]string{"baz", "quz"}) + require.NoError(t, err) + + gotValue = arrayType.Get() + require.Len(t, gotValue, 2) + + err = arrayType.AssignTo(&slice) + require.NoError(t, err) + require.EqualValues(t, []string{"baz", "quz"}, slice) +} + +func TestArrayTypeTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Value: pgtype.NewArrayType("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), + Name: "_text", + OID: pgtype.TextArrayOID, + }) + + var dstStrings []string + err := conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) + require.NoError(t, err) + + require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) +} + +func TestArrayTypeEmptyArrayDoesNotBreakArrayType(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Value: pgtype.NewArrayType("_text", pgtype.TextOID, func() pgtype.ValueTranscoder { return &pgtype.Text{} }), + Name: "_text", + OID: pgtype.TextArrayOID, + }) + + var dstStrings []string + err := conn.QueryRow(context.Background(), "select '{}'::text[]").Scan(&dstStrings) + require.NoError(t, err) + + require.EqualValues(t, []string{}, dstStrings) + + err = conn.QueryRow(context.Background(), "select $1::text[]", []string{"red", "green", "blue"}).Scan(&dstStrings) + require.NoError(t, err) + + require.EqualValues(t, []string{"red", "green", "blue"}, dstStrings) +} diff --git a/pgtype/bit.go b/pgtype/bit.go new file mode 100644 index 00000000..c1709e6b --- /dev/null +++ b/pgtype/bit.go @@ -0,0 +1,45 @@ +package pgtype + +import ( + "database/sql/driver" +) + +type Bit Varbit + +func (dst *Bit) Set(src interface{}) error { + return (*Varbit)(dst).Set(src) +} + +func (dst Bit) Get() interface{} { + return (Varbit)(dst).Get() +} + +func (src *Bit) AssignTo(dst interface{}) error { + return (*Varbit)(src).AssignTo(dst) +} + +func (dst *Bit) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Varbit)(dst).DecodeBinary(ci, src) +} + +func (src Bit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Varbit)(src).EncodeBinary(ci, buf) +} + +func (dst *Bit) DecodeText(ci *ConnInfo, src []byte) error { + return (*Varbit)(dst).DecodeText(ci, src) +} + +func (src Bit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Varbit)(src).EncodeText(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Bit) Scan(src interface{}) error { + return (*Varbit)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bit) Value() (driver.Value, error) { + return (Varbit)(src).Value() +} diff --git a/pgtype/bit_test.go b/pgtype/bit_test.go new file mode 100644 index 00000000..df5fe4cb --- /dev/null +++ b/pgtype/bit_test.go @@ -0,0 +1,25 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestBitTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "bit(40)", []interface{}{ + &pgtype.Varbit{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, + &pgtype.Varbit{}, + }) +} + +func TestBitNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select B'111111111'", + Value: &pgtype.Bit{Bytes: []byte{255, 128}, Len: 9, Valid: true}, + }, + }) +} diff --git a/pgtype/bool.go b/pgtype/bool.go new file mode 100644 index 00000000..4fcc67e3 --- /dev/null +++ b/pgtype/bool.go @@ -0,0 +1,197 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "strconv" +) + +type Bool struct { + Bool bool + Valid bool +} + +func (dst *Bool) Set(src interface{}) error { + if src == nil { + *dst = Bool{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case bool: + *dst = Bool{Bool: value, Valid: true} + case string: + bb, err := strconv.ParseBool(value) + if err != nil { + return err + } + *dst = Bool{Bool: bb, Valid: true} + case *bool: + if value == nil { + *dst = Bool{} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Bool{} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingBoolType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bool", value) + } + + return nil +} + +func (dst Bool) Get() interface{} { + if !dst.Valid { + return nil + } + + return dst.Bool +} + +func (src *Bool) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *bool: + *v = src.Bool + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } +} + +func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Bool{} + return nil + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + *dst = Bool{Bool: src[0] == 't', Valid: true} + return nil +} + +func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Bool{} + return nil + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + *dst = Bool{Bool: src[0] == 1, Valid: true} + return nil +} + +func (src Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if src.Bool { + buf = append(buf, 't') + } else { + buf = append(buf, 'f') + } + + return buf, nil +} + +func (src Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if src.Bool { + buf = append(buf, 1) + } else { + buf = append(buf, 0) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Bool) Scan(src interface{}) error { + if src == nil { + *dst = Bool{} + return nil + } + + switch src := src.(type) { + case bool: + *dst = Bool{Bool: src, Valid: true} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bool) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + return src.Bool, nil +} + +func (src Bool) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + if src.Bool { + return []byte("true"), nil + } else { + return []byte("false"), nil + } +} + +func (dst *Bool) UnmarshalJSON(b []byte) error { + var v *bool + err := json.Unmarshal(b, &v) + if err != nil { + return err + } + + if v == nil { + *dst = Bool{} + } else { + *dst = Bool{Bool: *v, Valid: true} + } + + return nil +} diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go new file mode 100644 index 00000000..a282fd6b --- /dev/null +++ b/pgtype/bool_array.go @@ -0,0 +1,504 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type BoolArray struct { + Elements []Bool + Dimensions []ArrayDimension + Valid bool +} + +func (dst *BoolArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = BoolArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []bool: + if value == nil { + *dst = BoolArray{} + } else if len(value) == 0 { + *dst = BoolArray{Valid: true} + } else { + elements := make([]Bool, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = BoolArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*bool: + if value == nil { + *dst = BoolArray{} + } else if len(value) == 0 { + *dst = BoolArray{Valid: true} + } else { + elements := make([]Bool, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = BoolArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Bool: + if value == nil { + *dst = BoolArray{} + } else if len(value) == 0 { + *dst = BoolArray{Valid: true} + } else { + *dst = BoolArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = BoolArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for BoolArray", src) + } + if elementsLength == 0 { + *dst = BoolArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to BoolArray", src) + } + + *dst = BoolArray{ + Elements: make([]Bool, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Bool, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to BoolArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *BoolArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to BoolArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in BoolArray", err) + } + index++ + + return index, nil +} + +func (dst BoolArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *BoolArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]bool: + *v = make([]bool, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*bool: + *v = make([]*bool, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from BoolArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from BoolArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = BoolArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Bool + + if len(uta.Elements) > 0 { + elements = make([]Bool, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Bool + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = BoolArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = BoolArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = BoolArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Bool, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = BoolArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("bool"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "bool") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *BoolArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src BoolArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/bool_array_test.go b/pgtype/bool_array_test.go new file mode 100644 index 00000000..cfb9ad79 --- /dev/null +++ b/pgtype/bool_array_test.go @@ -0,0 +1,283 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestBoolArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "bool[]", []interface{}{ + &pgtype.BoolArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.BoolArray{}, + &pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {}, + {Bool: false, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestBoolArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.BoolArray + }{ + { + source: []bool{true}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]bool)(nil)), + result: pgtype.BoolArray{}, + }, + { + source: [][]bool{{true}, {false}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1]bool{{true}, {false}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.BoolArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestBoolArrayAssignTo(t *testing.T) { + var boolSlice []bool + type _boolSlice []bool + var namedBoolSlice _boolSlice + var boolSliceDim2 [][]bool + var boolSliceDim4 [][][][]bool + var boolArrayDim2 [2][1]bool + var boolArrayDim4 [2][1][1][3]bool + + simpleTests := []struct { + src pgtype.BoolArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &boolSlice, + expected: []bool{true}, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &namedBoolSlice, + expected: _boolSlice{true}, + }, + { + src: pgtype.BoolArray{}, + dst: &boolSlice, + expected: (([]bool)(nil)), + }, + { + src: pgtype.BoolArray{Valid: true}, + dst: &boolSlice, + expected: []bool{}, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + expected: [][]bool{{true}, {false}}, + dst: &boolSliceDim2, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + expected: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + dst: &boolSliceDim4, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + expected: [2][1]bool{{true}, {false}}, + dst: &boolArrayDim2, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{ + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: true, Valid: true}, + {Bool: false, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + expected: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, + dst: &boolArrayDim4, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.BoolArray + dst interface{} + }{ + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &boolSlice, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &boolArrayDim2, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &boolSlice, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &boolArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go new file mode 100644 index 00000000..a1ba9bb0 --- /dev/null +++ b/pgtype/bool_test.go @@ -0,0 +1,140 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestBoolTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "bool", []interface{}{ + &pgtype.Bool{Bool: false, Valid: true}, + &pgtype.Bool{Bool: true, Valid: true}, + &pgtype.Bool{Bool: false}, + }) +} + +func TestBoolSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Bool + }{ + {source: true, result: pgtype.Bool{Bool: true, Valid: true}}, + {source: false, result: pgtype.Bool{Bool: false, Valid: true}}, + {source: "true", result: pgtype.Bool{Bool: true, Valid: true}}, + {source: "false", result: pgtype.Bool{Bool: false, Valid: true}}, + {source: "t", result: pgtype.Bool{Bool: true, Valid: true}}, + {source: "f", result: pgtype.Bool{Bool: false, Valid: true}}, + {source: _bool(true), result: pgtype.Bool{Bool: true, Valid: true}}, + {source: _bool(false), result: pgtype.Bool{Bool: false, Valid: true}}, + {source: nil, result: pgtype.Bool{}}, + } + + for i, tt := range successfulTests { + var r pgtype.Bool + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestBoolAssignTo(t *testing.T) { + var b bool + var _b _bool + var pb *bool + var _pb *_bool + + simpleTests := []struct { + src pgtype.Bool + dst interface{} + expected interface{} + }{ + {src: pgtype.Bool{Bool: false, Valid: true}, dst: &b, expected: false}, + {src: pgtype.Bool{Bool: true, Valid: true}, dst: &b, expected: true}, + {src: pgtype.Bool{Bool: false, Valid: true}, dst: &_b, expected: _bool(false)}, + {src: pgtype.Bool{Bool: true, Valid: true}, dst: &_b, expected: _bool(true)}, + {src: pgtype.Bool{Bool: false}, dst: &pb, expected: ((*bool)(nil))}, + {src: pgtype.Bool{Bool: false}, dst: &_pb, expected: ((*_bool)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Bool + dst interface{} + expected interface{} + }{ + {src: pgtype.Bool{Bool: true, Valid: true}, dst: &pb, expected: true}, + {src: pgtype.Bool{Bool: true, Valid: true}, dst: &_pb, expected: _bool(true)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} + +func TestBoolMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Bool + result string + }{ + {source: pgtype.Bool{}, result: "null"}, + {source: pgtype.Bool{Bool: true, Valid: true}, result: "true"}, + {source: pgtype.Bool{Bool: false, Valid: true}, result: "false"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestBoolUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Bool + }{ + {source: "null", result: pgtype.Bool{}}, + {source: "true", result: pgtype.Bool{Bool: true, Valid: true}}, + {source: "false", result: pgtype.Bool{Bool: false, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Bool + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/box.go b/pgtype/box.go new file mode 100644 index 00000000..868b40a2 --- /dev/null +++ b/pgtype/box.go @@ -0,0 +1,155 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +type Box struct { + P [2]Vec2 + Valid bool +} + +func (dst *Box) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Box", src) +} + +func (dst Box) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *Box) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Box{} + return nil + } + + if len(src) < 11 { + return fmt.Errorf("invalid length for Box: %v", len(src)) + } + + str := string(src[1:]) + + var end int + end = strings.IndexByte(str, ',') + + x1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+3:] + end = strings.IndexByte(str, ',') + + x2, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1 : len(str)-1] + + y2, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true} + return nil +} + +func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Box{} + return nil + } + + if len(src) != 32 { + return fmt.Errorf("invalid length for Box: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + *dst = Box{ + P: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Valid: true, + } + return nil +} + +func (src Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, + strconv.FormatFloat(src.P[0].X, 'f', -1, 64), + strconv.FormatFloat(src.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(src.P[1].X, 'f', -1, 64), + strconv.FormatFloat(src.P[1].Y, 'f', -1, 64), + )...) + return buf, nil +} + +func (src Box) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Box) Scan(src interface{}) error { + if src == nil { + *dst = Box{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Box) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/box_test.go b/pgtype/box_test.go new file mode 100644 index 00000000..c7e00553 --- /dev/null +++ b/pgtype/box_test.go @@ -0,0 +1,34 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestBoxTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "box", []interface{}{ + &pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, + Valid: true, + }, + &pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Valid: true, + }, + &pgtype.Box{}, + }) +} + +func TestBoxNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select '3.14, 1.678, 7.1, 5.234'::box", + Value: &pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, + Valid: true, + }, + }, + }) +} diff --git a/pgtype/bpchar.go b/pgtype/bpchar.go new file mode 100644 index 00000000..2e899ea8 --- /dev/null +++ b/pgtype/bpchar.go @@ -0,0 +1,92 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +// BPChar is fixed-length, blank padded char type +// character(n), char(n) +type BPChar Text + +// Set converts from src to dst. +func (dst *BPChar) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +// Get returns underlying value +func (dst BPChar) Get() interface{} { + return (Text)(dst).Get() +} + +// AssignTo assigns from src to dst. +func (src *BPChar) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *rune: + runes := []rune(src.String) + if len(runes) == 1 { + *v = runes[0] + return nil + } + case *string: + *v = src.String + return nil + case *[]byte: + *v = make([]byte, len(src.String)) + copy(*v, src.String) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (BPChar) PreferredResultFormat() int16 { + return TextFormatCode +} + +func (dst *BPChar) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *BPChar) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +func (BPChar) PreferredParamFormat() int16 { + return TextFormatCode +} + +func (src BPChar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeText(ci, buf) +} + +func (src BPChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *BPChar) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src BPChar) Value() (driver.Value, error) { + return (Text)(src).Value() +} + +func (src BPChar) MarshalJSON() ([]byte, error) { + return (Text)(src).MarshalJSON() +} + +func (dst *BPChar) UnmarshalJSON(b []byte) error { + return (*Text)(dst).UnmarshalJSON(b) +} diff --git a/pgtype/bpchar_array.go b/pgtype/bpchar_array.go new file mode 100644 index 00000000..c73c78a3 --- /dev/null +++ b/pgtype/bpchar_array.go @@ -0,0 +1,504 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type BPCharArray struct { + Elements []BPChar + Dimensions []ArrayDimension + Valid bool +} + +func (dst *BPCharArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = BPCharArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []string: + if value == nil { + *dst = BPCharArray{} + } else if len(value) == 0 { + *dst = BPCharArray{Valid: true} + } else { + elements := make([]BPChar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = BPCharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*string: + if value == nil { + *dst = BPCharArray{} + } else if len(value) == 0 { + *dst = BPCharArray{Valid: true} + } else { + elements := make([]BPChar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = BPCharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []BPChar: + if value == nil { + *dst = BPCharArray{} + } else if len(value) == 0 { + *dst = BPCharArray{Valid: true} + } else { + *dst = BPCharArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = BPCharArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for BPCharArray", src) + } + if elementsLength == 0 { + *dst = BPCharArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to BPCharArray", src) + } + + *dst = BPCharArray{ + Elements: make([]BPChar, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]BPChar, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to BPCharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *BPCharArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to BPCharArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in BPCharArray", err) + } + index++ + + return index, nil +} + +func (dst BPCharArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *BPCharArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *BPCharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from BPCharArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from BPCharArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *BPCharArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = BPCharArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []BPChar + + if len(uta.Elements) > 0 { + elements = make([]BPChar, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem BPChar + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = BPCharArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *BPCharArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = BPCharArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = BPCharArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]BPChar, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = BPCharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("bpchar"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "bpchar") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *BPCharArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src BPCharArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/bpchar_array_test.go b/pgtype/bpchar_array_test.go new file mode 100644 index 00000000..277f6e3c --- /dev/null +++ b/pgtype/bpchar_array_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestBPCharArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "char(8)[]", []interface{}{ + &pgtype.BPCharArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.BPCharArray{ + Elements: []pgtype.BPChar{ + pgtype.BPChar{String: "foo ", Valid: true}, + pgtype.BPChar{}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.BPCharArray{}, + &pgtype.BPCharArray{ + Elements: []pgtype.BPChar{ + pgtype.BPChar{String: "bar ", Valid: true}, + pgtype.BPChar{String: "NuLL ", Valid: true}, + pgtype.BPChar{String: `wow"quz\`, Valid: true}, + pgtype.BPChar{String: "1 ", Valid: true}, + pgtype.BPChar{String: "1 ", Valid: true}, + pgtype.BPChar{String: "null ", Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 3, LowerBound: 1}, + {Length: 2, LowerBound: 1}, + }, + Valid: true, + }, + &pgtype.BPCharArray{ + Elements: []pgtype.BPChar{ + pgtype.BPChar{String: " bar ", Valid: true}, + pgtype.BPChar{String: " baz ", Valid: true}, + pgtype.BPChar{String: " quz ", Valid: true}, + pgtype.BPChar{String: "foo ", Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} diff --git a/pgtype/bpchar_test.go b/pgtype/bpchar_test.go new file mode 100644 index 00000000..fe7e651c --- /dev/null +++ b/pgtype/bpchar_test.go @@ -0,0 +1,51 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestChar3Transcode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "char(3)", []interface{}{ + &pgtype.BPChar{String: "a ", Valid: true}, + &pgtype.BPChar{String: " a ", Valid: true}, + &pgtype.BPChar{String: "嗨 ", Valid: true}, + &pgtype.BPChar{String: " ", Valid: true}, + &pgtype.BPChar{}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.BPChar) + b := bb.(pgtype.BPChar) + + return a.Valid == b.Valid && a.String == b.String + }) +} + +func TestBPCharAssignTo(t *testing.T) { + var ( + str string + run rune + ) + simpleTests := []struct { + src pgtype.BPChar + dst interface{} + expected interface{} + }{ + {src: pgtype.BPChar{String: "simple", Valid: true}, dst: &str, expected: "simple"}, + {src: pgtype.BPChar{String: "嗨", Valid: true}, dst: &run, expected: '嗨'}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + +} diff --git a/pgtype/bytea.go b/pgtype/bytea.go new file mode 100644 index 00000000..d4c4e436 --- /dev/null +++ b/pgtype/bytea.go @@ -0,0 +1,146 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/hex" + "fmt" +) + +type Bytea struct { + Bytes []byte + Valid bool +} + +func (dst *Bytea) Set(src interface{}) error { + if src == nil { + *dst = Bytea{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case []byte: + if value != nil { + *dst = Bytea{Bytes: value, Valid: true} + } else { + *dst = Bytea{} + } + default: + if originalSrc, ok := underlyingBytesType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bytea", value) + } + + return nil +} + +func (dst Bytea) Get() interface{} { + if !dst.Valid { + return nil + } + return dst.Bytes +} + +func (src *Bytea) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *[]byte: + buf := make([]byte, len(src.Bytes)) + copy(buf, src.Bytes) + *v = buf + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } +} + +// DecodeText only supports the hex format. This has been the default since +// PostgreSQL 9.0. +func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Bytea{} + return nil + } + + if len(src) < 2 || src[0] != '\\' || src[1] != 'x' { + return fmt.Errorf("invalid hex format") + } + + buf := make([]byte, (len(src)-2)/2) + _, err := hex.Decode(buf, src[2:]) + if err != nil { + return err + } + + *dst = Bytea{Bytes: buf, Valid: true} + return nil +} + +func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Bytea{} + return nil + } + + *dst = Bytea{Bytes: src, Valid: true} + return nil +} + +func (src Bytea) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = append(buf, `\x`...) + buf = append(buf, hex.EncodeToString(src.Bytes)...) + return buf, nil +} + +func (src Bytea) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return append(buf, src.Bytes...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Bytea) Scan(src interface{}) error { + if src == nil { + *dst = Bytea{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + buf := make([]byte, len(src)) + copy(buf, src) + *dst = Bytea{Bytes: buf, Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bytea) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return src.Bytes, nil +} diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go new file mode 100644 index 00000000..7c539e21 --- /dev/null +++ b/pgtype/bytea_array.go @@ -0,0 +1,476 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type ByteaArray struct { + Elements []Bytea + Dimensions []ArrayDimension + Valid bool +} + +func (dst *ByteaArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = ByteaArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case [][]byte: + if value == nil { + *dst = ByteaArray{} + } else if len(value) == 0 { + *dst = ByteaArray{Valid: true} + } else { + elements := make([]Bytea, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = ByteaArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Bytea: + if value == nil { + *dst = ByteaArray{} + } else if len(value) == 0 { + *dst = ByteaArray{Valid: true} + } else { + *dst = ByteaArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = ByteaArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for ByteaArray", src) + } + if elementsLength == 0 { + *dst = ByteaArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to ByteaArray", src) + } + + *dst = ByteaArray{ + Elements: make([]Bytea, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Bytea, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to ByteaArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *ByteaArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to ByteaArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in ByteaArray", err) + } + index++ + + return index, nil +} + +func (dst ByteaArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *ByteaArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[][]byte: + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *ByteaArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from ByteaArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from ByteaArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = ByteaArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Bytea + + if len(uta.Elements) > 0 { + elements = make([]Bytea, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Bytea + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = ByteaArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = ByteaArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = ByteaArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Bytea, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = ByteaArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("bytea"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "bytea") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *ByteaArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src ByteaArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/bytea_array_test.go b/pgtype/bytea_array_test.go new file mode 100644 index 00000000..1473eb9c --- /dev/null +++ b/pgtype/bytea_array_test.go @@ -0,0 +1,229 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestByteaArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "bytea[]", []interface{}{ + &pgtype.ByteaArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.ByteaArray{}, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{}, Valid: true}, + {Bytes: []byte{1, 2, 3}, Valid: true}, + {}, + {Bytes: []byte{1}, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{}, Valid: true}, + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{1}, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestByteaArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ByteaArray + }{ + { + source: [][]byte{{1, 2, 3}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([][]byte)(nil)), + result: pgtype.ByteaArray{}, + }, + { + source: [][][]byte{{{1}}, {{2}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Valid: true}, {Bytes: []byte{2}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{4, 5, 6}, Valid: true}, + {Bytes: []byte{7, 8, 9}, Valid: true}, + {Bytes: []byte{10, 11, 12}, Valid: true}, + {Bytes: []byte{13, 14, 15}, Valid: true}, + {Bytes: []byte{16, 17, 18}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1][]byte{{{1}}, {{2}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Valid: true}, {Bytes: []byte{2}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{4, 5, 6}, Valid: true}, + {Bytes: []byte{7, 8, 9}, Valid: true}, + {Bytes: []byte{10, 11, 12}, Valid: true}, + {Bytes: []byte{13, 14, 15}, Valid: true}, + {Bytes: []byte{16, 17, 18}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.ByteaArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestByteaArrayAssignTo(t *testing.T) { + var byteByteSlice [][]byte + var byteByteSliceDim2 [][][]byte + var byteByteSliceDim4 [][][][][]byte + var byteByteArraySliceDim2 [2][1][]byte + var byteByteArraySliceDim4 [2][1][1][3][]byte + + simpleTests := []struct { + src pgtype.ByteaArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &byteByteSlice, + expected: [][]byte{{1, 2, 3}}, + }, + { + src: pgtype.ByteaArray{}, + dst: &byteByteSlice, + expected: (([][]byte)(nil)), + }, + { + src: pgtype.ByteaArray{Valid: true}, + dst: &byteByteSlice, + expected: [][]byte{}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Valid: true}, {Bytes: []byte{2}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &byteByteSliceDim2, + expected: [][][]byte{{{1}}, {{2}}}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{4, 5, 6}, Valid: true}, + {Bytes: []byte{7, 8, 9}, Valid: true}, + {Bytes: []byte{10, 11, 12}, Valid: true}, + {Bytes: []byte{13, 14, 15}, Valid: true}, + {Bytes: []byte{16, 17, 18}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &byteByteSliceDim4, + expected: [][][][][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1}, Valid: true}, {Bytes: []byte{2}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &byteByteArraySliceDim2, + expected: [2][1][]byte{{{1}}, {{2}}}, + }, + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + {Bytes: []byte{1, 2, 3}, Valid: true}, + {Bytes: []byte{4, 5, 6}, Valid: true}, + {Bytes: []byte{7, 8, 9}, Valid: true}, + {Bytes: []byte{10, 11, 12}, Valid: true}, + {Bytes: []byte{13, 14, 15}, Valid: true}, + {Bytes: []byte{16, 17, 18}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &byteByteArraySliceDim4, + expected: [2][1][1][3][]byte{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}}, {{{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go new file mode 100644 index 00000000..0f47cb7f --- /dev/null +++ b/pgtype/bytea_test.go @@ -0,0 +1,73 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestByteaTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "bytea", []interface{}{ + &pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}, + &pgtype.Bytea{Bytes: []byte{}, Valid: true}, + &pgtype.Bytea{Bytes: nil}, + }) +} + +func TestByteaSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Bytea + }{ + {source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}}, + {source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Valid: true}}, + {source: []byte(nil), result: pgtype.Bytea{}}, + {source: _byteSlice{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}}, + {source: _byteSlice(nil), result: pgtype.Bytea{}}, + } + + for i, tt := range successfulTests { + var r pgtype.Bytea + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestByteaAssignTo(t *testing.T) { + var buf []byte + var _buf _byteSlice + var pbuf *[]byte + var _pbuf *_byteSlice + + simpleTests := []struct { + src pgtype.Bytea + dst interface{} + expected interface{} + }{ + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}, dst: &buf, expected: []byte{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}, dst: &_buf, expected: _byteSlice{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}, dst: &pbuf, expected: &[]byte{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Valid: true}, dst: &_pbuf, expected: &_byteSlice{1, 2, 3}}, + {src: pgtype.Bytea{}, dst: &pbuf, expected: ((*[]byte)(nil))}, + {src: pgtype.Bytea{}, dst: &_pbuf, expected: ((*_byteSlice)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/cid.go b/pgtype/cid.go new file mode 100644 index 00000000..b944748c --- /dev/null +++ b/pgtype/cid.go @@ -0,0 +1,61 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// CID is PostgreSQL's Command Identifier type. +// +// When one does +// +// select cmin, cmax, * from some_table; +// +// it is the data type of the cmin and cmax hidden system columns. +// +// It is currently implemented as an unsigned four byte integer. +// Its definition can be found in src/include/c.h as CommandId +// in the PostgreSQL sources. +type CID pguint32 + +// Set converts from src to dst. Note that as CID is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *CID) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst CID) Get() interface{} { + return (pguint32)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as CID is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *CID) AssignTo(dst interface{}) error { + return (*pguint32)(src).AssignTo(dst) +} + +func (dst *CID) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) +} + +func (dst *CID) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) +} + +func (src CID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeText(ci, buf) +} + +func (src CID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *CID) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src CID) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go new file mode 100644 index 00000000..041cb805 --- /dev/null +++ b/pgtype/cid_test.go @@ -0,0 +1,102 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestCIDTranscode(t *testing.T) { + pgTypeName := "cid" + values := []interface{}{ + &pgtype.CID{Uint: 42, Valid: true}, + &pgtype.CID{}, + } + eqFunc := func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + } + + testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, "github.com/jackc/pgx/stdlib", pgTypeName, values, eqFunc) +} + +func TestCIDSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.CID + }{ + {source: uint32(1), result: pgtype.CID{Uint: 1, Valid: true}}, + } + + for i, tt := range successfulTests { + var r pgtype.CID + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestCIDAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.CID + dst interface{} + expected interface{} + }{ + {src: pgtype.CID{Uint: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.CID{}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.CID + dst interface{} + expected interface{} + }{ + {src: pgtype.CID{Uint: 42, Valid: true}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.CID + dst interface{} + }{ + {src: pgtype.CID{}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/cidr.go b/pgtype/cidr.go new file mode 100644 index 00000000..2241ca1c --- /dev/null +++ b/pgtype/cidr.go @@ -0,0 +1,31 @@ +package pgtype + +type CIDR Inet + +func (dst *CIDR) Set(src interface{}) error { + return (*Inet)(dst).Set(src) +} + +func (dst CIDR) Get() interface{} { + return (Inet)(dst).Get() +} + +func (src *CIDR) AssignTo(dst interface{}) error { + return (*Inet)(src).AssignTo(dst) +} + +func (dst *CIDR) DecodeText(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeText(ci, src) +} + +func (dst *CIDR) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeBinary(ci, src) +} + +func (src CIDR) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Inet)(src).EncodeText(ci, buf) +} + +func (src CIDR) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Inet)(src).EncodeBinary(ci, buf) +} diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go new file mode 100644 index 00000000..48a6a4c1 --- /dev/null +++ b/pgtype/cidr_array.go @@ -0,0 +1,533 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "net" + "reflect" + + "github.com/jackc/pgio" +) + +type CIDRArray struct { + Elements []CIDR + Dimensions []ArrayDimension + Valid bool +} + +func (dst *CIDRArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = CIDRArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []*net.IPNet: + if value == nil { + *dst = CIDRArray{} + } else if len(value) == 0 { + *dst = CIDRArray{Valid: true} + } else { + elements := make([]CIDR, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CIDRArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []net.IP: + if value == nil { + *dst = CIDRArray{} + } else if len(value) == 0 { + *dst = CIDRArray{Valid: true} + } else { + elements := make([]CIDR, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CIDRArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*net.IP: + if value == nil { + *dst = CIDRArray{} + } else if len(value) == 0 { + *dst = CIDRArray{Valid: true} + } else { + elements := make([]CIDR, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CIDRArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []CIDR: + if value == nil { + *dst = CIDRArray{} + } else if len(value) == 0 { + *dst = CIDRArray{Valid: true} + } else { + *dst = CIDRArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = CIDRArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for CIDRArray", src) + } + if elementsLength == 0 { + *dst = CIDRArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to CIDRArray", src) + } + + *dst = CIDRArray{ + Elements: make([]CIDR, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]CIDR, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to CIDRArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *CIDRArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to CIDRArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in CIDRArray", err) + } + index++ + + return index, nil +} + +func (dst CIDRArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *CIDRArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]*net.IPNet: + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]net.IP: + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*net.IP: + *v = make([]*net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from CIDRArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from CIDRArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CIDRArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []CIDR + + if len(uta.Elements) > 0 { + elements = make([]CIDR, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem CIDR + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = CIDRArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CIDRArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = CIDRArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]CIDR, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = CIDRArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("cidr"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "cidr") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *CIDRArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src CIDRArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/cidr_array_test.go b/pgtype/cidr_array_test.go new file mode 100644 index 00000000..7821cf44 --- /dev/null +++ b/pgtype/cidr_array_test.go @@ -0,0 +1,319 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestCIDRArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "cidr[]", []interface{}{ + &pgtype.CIDRArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.CIDRArray{}, + &pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, + {}, + {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestCIDRArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.CIDRArray + }{ + { + source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]*net.IPNet)(nil)), + result: pgtype.CIDRArray{}, + }, + { + source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]net.IP)(nil)), + result: pgtype.CIDRArray{}, + }, + { + source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.CIDRArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestCIDRArrayAssignTo(t *testing.T) { + var ipnetSlice []*net.IPNet + var ipSlice []net.IP + var ipSliceDim2 [][]net.IP + var ipnetSliceDim4 [][][][]*net.IPNet + var ipArrayDim2 [2][1]net.IP + var ipnetArrayDim4 [2][1][1][3]*net.IPNet + + simpleTests := []struct { + src pgtype.CIDRArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{nil}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &ipSlice, + expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &ipSlice, + expected: []net.IP{nil}, + }, + { + src: pgtype.CIDRArray{}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, + { + src: pgtype.CIDRArray{Valid: true}, + dst: &ipnetSlice, + expected: []*net.IPNet{}, + }, + { + src: pgtype.CIDRArray{}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, + { + src: pgtype.CIDRArray{Valid: true}, + dst: &ipSlice, + expected: []net.IP{}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &ipSliceDim2, + expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &ipnetSliceDim4, + expected: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &ipArrayDim2, + expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &ipnetArrayDim4, + expected: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/circle.go b/pgtype/circle.go new file mode 100644 index 00000000..7524d7b9 --- /dev/null +++ b/pgtype/circle.go @@ -0,0 +1,140 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +type Circle struct { + P Vec2 + R float64 + Valid bool +} + +func (dst *Circle) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Circle", src) +} + +func (dst Circle) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *Circle) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Circle{} + return nil + } + + if len(src) < 9 { + return fmt.Errorf("invalid length for Circle: %v", len(src)) + } + + str := string(src[2:]) + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+2 : len(str)-1] + + r, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Circle{P: Vec2{x, y}, R: r, Valid: true} + return nil +} + +func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Circle{} + return nil + } + + if len(src) != 24 { + return fmt.Errorf("invalid length for Circle: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + r := binary.BigEndian.Uint64(src[16:]) + + *dst = Circle{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + R: math.Float64frombits(r), + Valid: true, + } + return nil +} + +func (src Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`, + strconv.FormatFloat(src.P.X, 'f', -1, 64), + strconv.FormatFloat(src.P.Y, 'f', -1, 64), + strconv.FormatFloat(src.R, 'f', -1, 64), + )...) + + return buf, nil +} + +func (src Circle) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.R)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Circle) Scan(src interface{}) error { + if src == nil { + *dst = Circle{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Circle) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/circle_test.go b/pgtype/circle_test.go new file mode 100644 index 00000000..416a1a41 --- /dev/null +++ b/pgtype/circle_test.go @@ -0,0 +1,16 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestCircleTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "circle", []interface{}{ + &pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, + &pgtype.Circle{P: pgtype.Vec2{-1.234, -5.6789}, R: 12.9, Valid: true}, + &pgtype.Circle{}, + }) +} diff --git a/pgtype/composite_bench_test.go b/pgtype/composite_bench_test.go new file mode 100644 index 00000000..a1d91f8e --- /dev/null +++ b/pgtype/composite_bench_test.go @@ -0,0 +1,192 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgio" + "github.com/jackc/pgtype" + "github.com/stretchr/testify/require" +) + +type MyCompositeRaw struct { + A int32 + B *string +} + +func (src MyCompositeRaw) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + buf = pgio.AppendUint32(buf, 2) + + buf = pgio.AppendUint32(buf, pgtype.Int4OID) + buf = pgio.AppendInt32(buf, 4) + buf = pgio.AppendInt32(buf, src.A) + + buf = pgio.AppendUint32(buf, pgtype.TextOID) + if src.B != nil { + buf = pgio.AppendInt32(buf, int32(len(*src.B))) + buf = append(buf, (*src.B)...) + } else { + buf = pgio.AppendInt32(buf, -1) + } + + return buf, nil +} + +func (dst *MyCompositeRaw) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + a := pgtype.Int4{} + b := pgtype.Text{} + + scanner := pgtype.NewCompositeBinaryScanner(ci, src) + scanner.ScanDecoder(&a) + scanner.ScanDecoder(&b) + + if scanner.Err() != nil { + return scanner.Err() + } + + dst.A = a.Int + if b.Valid { + dst.B = &b.String + } else { + dst.B = nil + } + + return nil +} + +var x []byte + +func BenchmarkBinaryEncodingManual(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + v := MyCompositeRaw{4, ptrS("ABCDEFG")} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + buf, _ = v.EncodeBinary(ci, buf[:0]) + } + x = buf +} + +func BenchmarkBinaryEncodingHelper(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + v := MyType{4, ptrS("ABCDEFG")} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + buf, _ = v.EncodeBinary(ci, buf[:0]) + } + x = buf +} + +func BenchmarkBinaryEncodingComposite(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + f1 := 2 + f2 := ptrS("bar") + c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.Int4OID}, + {"b", pgtype.TextOID}, + }, ci) + require.NoError(b, err) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + c.Set([]interface{}{f1, f2}) + buf, _ = c.EncodeBinary(ci, buf[:0]) + } + x = buf +} + +func BenchmarkBinaryEncodingJSON(b *testing.B) { + buf := make([]byte, 0, 128) + ci := pgtype.NewConnInfo() + v := MyCompositeRaw{4, ptrS("ABCDEFG")} + j := pgtype.JSON{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + j.Set(v) + buf, _ = j.EncodeBinary(ci, buf[:0]) + } + x = buf +} + +var dstRaw MyCompositeRaw + +func BenchmarkBinaryDecodingManual(b *testing.B) { + ci := pgtype.NewConnInfo() + buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) + dst := MyCompositeRaw{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := dst.DecodeBinary(ci, buf) + E(err) + } + dstRaw = dst +} + +var dstMyType MyType + +func BenchmarkBinaryDecodingHelpers(b *testing.B) { + ci := pgtype.NewConnInfo() + buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) + dst := MyType{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := dst.DecodeBinary(ci, buf) + E(err) + } + dstMyType = dst +} + +var gf1 int +var gf2 *string + +func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { + ci := pgtype.NewConnInfo() + buf, _ := MyType{4, ptrS("ABCDEFG")}.EncodeBinary(ci, nil) + var f1 int + var f2 *string + + c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.Int4OID}, + {"b", pgtype.TextOID}, + }, ci) + require.NoError(b, err) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := c.DecodeBinary(ci, buf) + if err != nil { + b.Fatal(err) + } + err = c.AssignTo([]interface{}{&f1, &f2}) + if err != nil { + b.Fatal(err) + } + } + gf1 = f1 + gf2 = f2 +} + +func BenchmarkBinaryDecodingJSON(b *testing.B) { + ci := pgtype.NewConnInfo() + j := pgtype.JSON{} + j.Set(MyCompositeRaw{4, ptrS("ABCDEFG")}) + buf, _ := j.EncodeBinary(ci, nil) + + j = pgtype.JSON{} + dst := MyCompositeRaw{} + + b.ResetTimer() + for n := 0; n < b.N; n++ { + err := j.DecodeBinary(ci, buf) + E(err) + err = j.AssignTo(&dst) + E(err) + } + dstRaw = dst +} diff --git a/pgtype/composite_fields.go b/pgtype/composite_fields.go new file mode 100644 index 00000000..e7ca89c7 --- /dev/null +++ b/pgtype/composite_fields.go @@ -0,0 +1,107 @@ +package pgtype + +import "fmt" + +// CompositeFields scans the fields of a composite type into the elements of the CompositeFields value. To scan a +// nullable value use a *CompositeFields. It will be set to nil in case of null. +// +// CompositeFields implements EncodeBinary and EncodeText. However, functionality is limited due to CompositeFields not +// knowing the PostgreSQL schema of the composite type. Prefer using a registered CompositeType. +type CompositeFields []interface{} + +func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error { + if len(cf) == 0 { + return fmt.Errorf("cannot decode into empty CompositeFields") + } + + if src == nil { + return fmt.Errorf("cannot decode unexpected null into CompositeFields") + } + + scanner := NewCompositeBinaryScanner(ci, src) + + for _, f := range cf { + scanner.ScanValue(f) + } + + if scanner.Err() != nil { + return scanner.Err() + } + + return nil +} + +func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error { + if len(cf) == 0 { + return fmt.Errorf("cannot decode into empty CompositeFields") + } + + if src == nil { + return fmt.Errorf("cannot decode unexpected null into CompositeFields") + } + + scanner := NewCompositeTextScanner(ci, src) + + for _, f := range cf { + scanner.ScanValue(f) + } + + if scanner.Err() != nil { + return scanner.Err() + } + + return nil +} + +// EncodeText encodes composite fields into the text format. Prefer registering a CompositeType to using +// CompositeFields to encode directly. +func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + b := NewCompositeTextBuilder(ci, buf) + + for _, f := range cf { + if paramEncoder, ok := f.(ParamEncoder); ok { + b.AppendEncoder(paramEncoder) + } else { + b.AppendValue(f) + } + } + + return b.Finish() +} + +// EncodeBinary encodes composite fields into the binary format. Unlike CompositeType the schema of the destination is +// unknown. Prefer registering a CompositeType to using CompositeFields to encode directly. Because the binary +// composite format requires the OID of each field to be specified the only types that will work are those known to +// ConnInfo. +// +// In particular: +// +// * Nil cannot be used because there is no way to determine what type it. +// * Integer types must be exact matches. e.g. A Go int32 into a PostgreSQL bigint will fail. +// * No dereferencing will be done. e.g. *Text must be used instead of Text. +func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + b := NewCompositeBinaryBuilder(ci, buf) + + for _, f := range cf { + dt, ok := ci.DataTypeForValue(f) + if !ok { + return nil, fmt.Errorf("Unknown OID for %#v", f) + } + + if paramEncoder, ok := f.(ParamEncoder); ok { + b.AppendEncoder(dt.OID, paramEncoder) + } else { + err := dt.Value.Set(f) + if err != nil { + return nil, err + } + if paramEncoder, ok := dt.Value.(ParamEncoder); ok { + b.AppendEncoder(dt.OID, paramEncoder) + } else { + return nil, fmt.Errorf("Cannot encode binary format for %v", f) + } + } + } + + return b.Finish() +} diff --git a/pgtype/composite_fields_test.go b/pgtype/composite_fields_test.go new file mode 100644 index 00000000..be0b8125 --- /dev/null +++ b/pgtype/composite_fields_test.go @@ -0,0 +1,273 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCompositeFieldsDecode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + formats := []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} + + // Assorted values + { + var a int32 + var b string + var c float64 + + for _, format := range formats { + err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b, &c}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.EqualValuesf(t, 1, a, "Format: %v", format) + assert.EqualValuesf(t, "hi", b, "Format: %v", format) + assert.EqualValuesf(t, 2.1, c, "Format: %v", format) + } + } + + // nulls, string "null", and empty string fields + { + var a pgtype.Text + var b string + var c pgtype.Text + var d string + var e pgtype.Text + + for _, format := range formats { + err := conn.QueryRow(context.Background(), "select row(null,'null',null,'',null)", pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.Nilf(t, a.Get(), "Format: %v", format) + assert.EqualValuesf(t, "null", b, "Format: %v", format) + assert.Nilf(t, c.Get(), "Format: %v", format) + assert.EqualValuesf(t, "", d, "Format: %v", format) + assert.Nilf(t, e.Get(), "Format: %v", format) + } + } + + // null record + { + var a pgtype.Text + var b string + cf := pgtype.CompositeFields{&a, &b} + + for _, format := range formats { + // Cannot scan nil into + err := conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan( + cf, + ) + if assert.Errorf(t, err, "Format: %v", format) { + continue + } + assert.NotNilf(t, cf, "Format: %v", format) + + // But can scan nil into *pgtype.CompositeFields + err = conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan( + &cf, + ) + if assert.Errorf(t, err, "Format: %v", format) { + continue + } + assert.Nilf(t, cf, "Format: %v", format) + } + } + + // quotes and special characters + { + var a, b, c, d string + + for _, format := range formats { + err := conn.QueryRow(context.Background(), `select row('"', 'foo bar', 'foo''bar', 'baz)bar')`, pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b, &c, &d}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.Equalf(t, `"`, a, "Format: %v", format) + assert.Equalf(t, `foo bar`, b, "Format: %v", format) + assert.Equalf(t, `foo'bar`, c, "Format: %v", format) + assert.Equalf(t, `baz)bar`, d, "Format: %v", format) + } + } + + // arrays + { + var a []string + var b []int64 + + for _, format := range formats { + err := conn.QueryRow(context.Background(), `select row(array['foo', 'bar', 'baz'], array[1,2,3])`, pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, &b}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.EqualValuesf(t, []string{"foo", "bar", "baz"}, a, "Format: %v", format) + assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format) + } + } + + // Skip nil fields + { + var a int32 + var c float64 + + for _, format := range formats { + err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan( + pgtype.CompositeFields{&a, nil, &c}, + ) + if !assert.NoErrorf(t, err, "Format: %v", format) { + continue + } + + assert.EqualValuesf(t, 1, a, "Format: %v", format) + assert.EqualValuesf(t, 2.1, c, "Format: %v", format) + } + } +} + +func TestCompositeFieldsEncode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists cf_encode; + +create type cf_encode as ( + a text, + b int4, + c text, + d float8, + e text +);`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type cf_encode") + + // Use simple protocol to force text or binary encoding + simpleProtocols := []bool{true, false} + + // Assorted values + { + var a string + var b int32 + var c string + var d float64 + var e string + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{"hi", int32(1), "ok", float64(2.1), "bye"}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "ok", c, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 2.1, d, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "bye", e, "Simple Protocol: %v", simpleProtocol) + } + } + } + + // untyped nil + { + var a pgtype.Text + var b int32 + var c string + var d pgtype.Float8 + var e pgtype.Text + + simpleProtocol := true + err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{nil, int32(1), "null", nil, nil}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol) + } + + // untyped nil cannot be represented in binary format because CompositeFields does not know the PostgreSQL schema + // of the composite type. + simpleProtocol = false + err = conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{nil, int32(1), "null", nil, nil}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + assert.Errorf(t, err, "Simple Protocol: %v", simpleProtocol) + } + + // nulls, string "null", and empty string fields + { + var a pgtype.Text + var b int32 + var c string + var d pgtype.Float8 + var e pgtype.Text + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{&pgtype.Text{}, int32(1), "null", &pgtype.Float8{}, &pgtype.Text{}}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol) + assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol) + } + } + } + + // quotes and special characters + { + var a string + var b int32 + var c string + var d float64 + var e string + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow( + context.Background(), + `select $1::cf_encode`, + pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{`"`, int32(42), `foo'bar`, float64(1.2), `baz)bar`}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c, &d, &e}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.Equalf(t, `"`, a, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, int32(42), b, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, `foo'bar`, c, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, float64(1.2), d, "Simple Protocol: %v", simpleProtocol) + assert.Equalf(t, `baz)bar`, e, "Simple Protocol: %v", simpleProtocol) + } + } + } +} diff --git a/pgtype/composite_type.go b/pgtype/composite_type.go new file mode 100644 index 00000000..85ab5910 --- /dev/null +++ b/pgtype/composite_type.go @@ -0,0 +1,715 @@ +package pgtype + +import ( + "encoding/binary" + "errors" + "fmt" + "reflect" + "strings" + + "github.com/jackc/pgio" +) + +type CompositeTypeField struct { + Name string + OID uint32 +} + +type CompositeType struct { + valid bool + + typeName string + + fields []CompositeTypeField + valueTranscoders []ValueTranscoder +} + +// NewCompositeType creates a CompositeType from fields and ci. ci is used to find the ValueTranscoders used +// for fields. All field OIDs must be previously registered in ci. +func NewCompositeType(typeName string, fields []CompositeTypeField, ci *ConnInfo) (*CompositeType, error) { + valueTranscoders := make([]ValueTranscoder, len(fields)) + + for i := range fields { + dt, ok := ci.DataTypeForOID(fields[i].OID) + if !ok { + return nil, fmt.Errorf("no data type registered for oid: %d", fields[i].OID) + } + + value := NewValue(dt.Value) + valueTranscoder, ok := value.(ValueTranscoder) + if !ok { + return nil, fmt.Errorf("data type for oid does not implement ValueTranscoder: %d", fields[i].OID) + } + + valueTranscoders[i] = valueTranscoder + } + + return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: valueTranscoders}, nil +} + +// NewCompositeTypeValues creates a CompositeType from fields and values. fields and values must have the same length. +// Prefer NewCompositeType unless overriding the transcoding of fields is required. +func NewCompositeTypeValues(typeName string, fields []CompositeTypeField, values []ValueTranscoder) (*CompositeType, error) { + if len(fields) != len(values) { + return nil, errors.New("fields and valueTranscoders must have same length") + } + + return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: values}, nil +} + +func (src CompositeType) Get() interface{} { + if !src.valid { + return nil + } + + results := make(map[string]interface{}, len(src.valueTranscoders)) + for i := range src.valueTranscoders { + results[src.fields[i].Name] = src.valueTranscoders[i].Get() + } + return results +} + +func (ct *CompositeType) NewTypeValue() Value { + a := &CompositeType{ + typeName: ct.typeName, + fields: ct.fields, + valueTranscoders: make([]ValueTranscoder, len(ct.valueTranscoders)), + } + + for i := range ct.valueTranscoders { + a.valueTranscoders[i] = NewValue(ct.valueTranscoders[i]).(ValueTranscoder) + } + + return a +} + +func (ct *CompositeType) TypeName() string { + return ct.typeName +} + +func (ct *CompositeType) Fields() []CompositeTypeField { + return ct.fields +} + +func (dst *CompositeType) setNil() { + dst.valid = false +} + +func (dst *CompositeType) Set(src interface{}) error { + if src == nil { + dst.setNil() + return nil + } + + switch value := src.(type) { + case []interface{}: + if len(value) != len(dst.valueTranscoders) { + return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.valueTranscoders)) + } + for i, v := range value { + if err := dst.valueTranscoders[i].Set(v); err != nil { + return err + } + } + dst.valid = true + case *[]interface{}: + if value == nil { + dst.setNil() + return nil + } + return dst.Set(*value) + default: + return fmt.Errorf("Can not convert %v to Composite", src) + } + + return nil +} + +// AssignTo should never be called on composite value directly +func (src CompositeType) AssignTo(dst interface{}) error { + if !src.valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case []interface{}: + if len(v) != len(src.valueTranscoders) { + return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders)) + } + for i := range src.valueTranscoders { + if v[i] == nil { + continue + } + + err := assignToOrSet(src.valueTranscoders[i], v[i]) + if err != nil { + return fmt.Errorf("unable to assign to dst[%d]: %v", i, err) + } + } + return nil + case *[]interface{}: + return src.AssignTo(*v) + default: + if isPtrStruct, err := src.assignToPtrStruct(dst); isPtrStruct { + return err + } + + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } +} + +func assignToOrSet(src Value, dst interface{}) error { + assignToErr := src.AssignTo(dst) + if assignToErr != nil { + // Try to use get / set instead -- this avoids every type having to be able to AssignTo type of self. + setSucceeded := false + if setter, ok := dst.(Value); ok { + err := setter.Set(src.Get()) + setSucceeded = err == nil + } + if !setSucceeded { + return assignToErr + } + } + + return nil +} + +func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) { + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() != reflect.Ptr { + return false, nil + } + + if dstValue.IsNil() { + return false, nil + } + + dstElemValue := dstValue.Elem() + dstElemType := dstElemValue.Type() + + if dstElemType.Kind() != reflect.Struct { + return false, nil + } + + exportedFields := make([]int, 0, dstElemType.NumField()) + for i := 0; i < dstElemType.NumField(); i++ { + sf := dstElemType.Field(i) + if sf.PkgPath == "" { + exportedFields = append(exportedFields, i) + } + } + + if len(exportedFields) != len(src.valueTranscoders) { + return false, nil + } + + for i := range exportedFields { + err := assignToOrSet(src.valueTranscoders[i], dstElemValue.Field(exportedFields[i]).Addr().Interface()) + if err != nil { + return true, fmt.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err) + } + } + + return true, nil +} + +func (ct *CompositeType) BinaryFormatSupported() bool { + for _, vt := range ct.valueTranscoders { + if !vt.BinaryFormatSupported() { + return false + } + } + return true +} + +func (ct *CompositeType) TextFormatSupported() bool { + for _, vt := range ct.valueTranscoders { + if !vt.TextFormatSupported() { + return false + } + } + return true +} + +func (ct *CompositeType) PreferredFormat() int16 { + if ct.BinaryFormatSupported() { + return BinaryFormatCode + } + return TextFormatCode +} + +func (dst *CompositeType) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + if src == nil { + dst.setNil() + return nil + } + + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src CompositeType) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} + +func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { + if !src.valid { + return nil, nil + } + + b := NewCompositeBinaryBuilder(ci, buf) + for i := range src.valueTranscoders { + b.AppendEncoder(src.fields[i].OID, src.valueTranscoders[i]) + } + + return b.Finish() +} + +// DecodeBinary implements BinaryDecoder interface. +// Opposite to Record, fields in a composite act as a "schema" +// and decoding fails if SQL value can't be assigned due to +// type mismatch +func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { + scanner := NewCompositeBinaryScanner(ci, buf) + + for _, f := range dst.valueTranscoders { + scanner.ScanDecoder(f) + } + + if scanner.Err() != nil { + return scanner.Err() + } + + dst.valid = true + + return nil +} + +func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error { + scanner := NewCompositeTextScanner(ci, buf) + + for _, f := range dst.valueTranscoders { + scanner.ScanDecoder(f) + } + + if scanner.Err() != nil { + return scanner.Err() + } + + dst.valid = true + + return nil +} + +func (src CompositeType) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { + if !src.valid { + return nil, nil + } + + b := NewCompositeTextBuilder(ci, buf) + for _, f := range src.valueTranscoders { + b.AppendEncoder(f) + } + + return b.Finish() +} + +type CompositeBinaryScanner struct { + ci *ConnInfo + rp int + src []byte + + fieldCount int32 + fieldBytes []byte + fieldOID uint32 + err error +} + +// NewCompositeBinaryScanner a scanner over a binary encoded composite balue. +func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner { + rp := 0 + if len(src[rp:]) < 4 { + return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)} + } + + fieldCount := int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + return &CompositeBinaryScanner{ + ci: ci, + rp: rp, + src: src, + fieldCount: fieldCount, + } +} + +// ScanDecoder calls Next and decodes the result with d. +func (cfs *CompositeBinaryScanner) ScanDecoder(d ResultDecoder) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = d.DecodeResult(cfs.ci, 0, BinaryFormatCode, cfs.fieldBytes) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// ScanDecoder calls Next and scans the result into d. +func (cfs *CompositeBinaryScanner) ScanValue(d interface{}) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = cfs.ci.Scan(cfs.OID(), BinaryFormatCode, cfs.Bytes(), d) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Next returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeBinaryScanner) Next() bool { + if cfs.err != nil { + return false + } + + if cfs.rp == len(cfs.src) { + return false + } + + if len(cfs.src[cfs.rp:]) < 8 { + cfs.err = fmt.Errorf("Record incomplete %v", cfs.src) + return false + } + cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:]) + cfs.rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:]))) + cfs.rp += 4 + + if fieldLen >= 0 { + if len(cfs.src[cfs.rp:]) < fieldLen { + cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src) + return false + } + cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen] + cfs.rp += fieldLen + } else { + cfs.fieldBytes = nil + } + + return true +} + +func (cfs *CompositeBinaryScanner) FieldCount() int { + return int(cfs.fieldCount) +} + +// Bytes returns the bytes of the field most recently read by Scan(). +func (cfs *CompositeBinaryScanner) Bytes() []byte { + return cfs.fieldBytes +} + +// OID returns the OID of the field most recently read by Scan(). +func (cfs *CompositeBinaryScanner) OID() uint32 { + return cfs.fieldOID +} + +// Err returns any error encountered by the scanner. +func (cfs *CompositeBinaryScanner) Err() error { + return cfs.err +} + +type CompositeTextScanner struct { + ci *ConnInfo + rp int + src []byte + + fieldBytes []byte + err error +} + +// NewCompositeTextScanner a scanner over a text encoded composite value. +func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner { + if len(src) < 2 { + return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)} + } + + if src[0] != '(' { + return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")} + } + + if src[len(src)-1] != ')' { + return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")} + } + + return &CompositeTextScanner{ + ci: ci, + rp: 1, + src: src, + } +} + +// ScanDecoder calls Next and decodes the result with d. +func (cfs *CompositeTextScanner) ScanDecoder(d ResultDecoder) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = d.DecodeResult(cfs.ci, 0, TextFormatCode, cfs.fieldBytes) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// ScanDecoder calls Next and scans the result into d. +func (cfs *CompositeTextScanner) ScanValue(d interface{}) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = cfs.ci.Scan(0, TextFormatCode, cfs.Bytes(), d) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Next returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeTextScanner) Next() bool { + if cfs.err != nil { + return false + } + + if cfs.rp == len(cfs.src) { + return false + } + + switch cfs.src[cfs.rp] { + case ',', ')': // null + cfs.rp++ + cfs.fieldBytes = nil + return true + case '"': // quoted value + cfs.rp++ + cfs.fieldBytes = make([]byte, 0, 16) + for { + ch := cfs.src[cfs.rp] + + if ch == '"' { + cfs.rp++ + if cfs.src[cfs.rp] == '"' { + cfs.fieldBytes = append(cfs.fieldBytes, '"') + cfs.rp++ + } else { + break + } + } else if ch == '\\' { + cfs.rp++ + cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp]) + cfs.rp++ + } else { + cfs.fieldBytes = append(cfs.fieldBytes, ch) + cfs.rp++ + } + } + cfs.rp++ + return true + default: // unquoted value + start := cfs.rp + for { + ch := cfs.src[cfs.rp] + if ch == ',' || ch == ')' { + break + } + cfs.rp++ + } + cfs.fieldBytes = cfs.src[start:cfs.rp] + cfs.rp++ + return true + } +} + +// Bytes returns the bytes of the field most recently read by Scan(). +func (cfs *CompositeTextScanner) Bytes() []byte { + return cfs.fieldBytes +} + +// Err returns any error encountered by the scanner. +func (cfs *CompositeTextScanner) Err() error { + return cfs.err +} + +type CompositeBinaryBuilder struct { + ci *ConnInfo + buf []byte + startIdx int + fieldCount uint32 + err error +} + +func NewCompositeBinaryBuilder(ci *ConnInfo, buf []byte) *CompositeBinaryBuilder { + startIdx := len(buf) + buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields + return &CompositeBinaryBuilder{ci: ci, buf: buf, startIdx: startIdx} +} + +func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) { + if b.err != nil { + return + } + + dt, ok := b.ci.DataTypeForOID(oid) + if !ok { + b.err = fmt.Errorf("unknown data type for OID: %d", oid) + return + } + + err := dt.Value.Set(field) + if err != nil { + b.err = err + return + } + + paramEncoder, ok := dt.Value.(ParamEncoder) + if !ok { + b.err = fmt.Errorf("unable to encode for OID: %d", oid) + return + } + + b.AppendEncoder(oid, paramEncoder) +} + +func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field ParamEncoder) { + if b.err != nil { + return + } + + b.buf = pgio.AppendUint32(b.buf, oid) + lengthPos := len(b.buf) + b.buf = pgio.AppendInt32(b.buf, -1) + fieldBuf, err := field.EncodeParam(b.ci, oid, BinaryFormatCode, b.buf) + if err != nil { + b.err = err + return + } + if fieldBuf != nil { + binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf))) + b.buf = fieldBuf + } + + b.fieldCount++ +} + +func (b *CompositeBinaryBuilder) Finish() ([]byte, error) { + if b.err != nil { + return nil, b.err + } + + binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount) + return b.buf, nil +} + +type CompositeTextBuilder struct { + ci *ConnInfo + buf []byte + startIdx int + fieldCount uint32 + err error + fieldBuf [32]byte +} + +func NewCompositeTextBuilder(ci *ConnInfo, buf []byte) *CompositeTextBuilder { + buf = append(buf, '(') // allocate room for number of fields + return &CompositeTextBuilder{ci: ci, buf: buf} +} + +func (b *CompositeTextBuilder) AppendValue(field interface{}) { + if b.err != nil { + return + } + + if field == nil { + b.buf = append(b.buf, ',') + return + } + + dt, ok := b.ci.DataTypeForValue(field) + if !ok { + b.err = fmt.Errorf("unknown data type for field: %v", field) + return + } + + err := dt.Value.Set(field) + if err != nil { + b.err = err + return + } + + paramEncoder, ok := dt.Value.(ParamEncoder) + if !ok { + b.err = fmt.Errorf("unable to encode for value: %v", field) + return + } + + b.AppendEncoder(paramEncoder) +} + +func (b *CompositeTextBuilder) AppendEncoder(field ParamEncoder) { + if b.err != nil { + return + } + + fieldBuf, err := field.EncodeParam(b.ci, 0, TextFormatCode, b.fieldBuf[0:0]) + if err != nil { + b.err = err + return + } + if fieldBuf != nil { + b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...) + } + + b.buf = append(b.buf, ',') +} + +func (b *CompositeTextBuilder) Finish() ([]byte, error) { + if b.err != nil { + return nil, b.err + } + + b.buf[len(b.buf)-1] = ')' + return b.buf, nil +} + +var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteCompositeField(src string) string { + return `"` + quoteCompositeReplacer.Replace(src) + `"` +} + +func quoteCompositeFieldIfNeeded(src string) string { + if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) { + return quoteCompositeField(src) + } + return src +} diff --git a/pgtype/composite_type_test.go b/pgtype/composite_type_test.go new file mode 100644 index 00000000..e06927fa --- /dev/null +++ b/pgtype/composite_type_test.go @@ -0,0 +1,320 @@ +package pgtype_test + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + pgx "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCompositeTypeSetAndGet(t *testing.T) { + ci := pgtype.NewConnInfo() + ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.TextOID}, + {"b", pgtype.Int4OID}, + }, ci) + require.NoError(t, err) + assert.Equal(t, nil, ct.Get()) + + nilTests := []struct { + src interface{} + }{ + {nil}, // nil interface + {(*[]interface{})(nil)}, // typed nil + } + + for i, tt := range nilTests { + err := ct.Set(tt.src) + assert.NoErrorf(t, err, "%d", i) + assert.Equal(t, nil, ct.Get()) + } + + compatibleValuesTests := []struct { + src []interface{} + expected map[string]interface{} + }{ + { + src: []interface{}{"foo", int32(42)}, + expected: map[string]interface{}{"a": "foo", "b": int32(42)}, + }, + { + src: []interface{}{nil, nil}, + expected: map[string]interface{}{"a": nil, "b": nil}, + }, + { + src: []interface{}{&pgtype.Text{String: "hi", Valid: true}, &pgtype.Int4{Int: 7, Valid: true}}, + expected: map[string]interface{}{"a": "hi", "b": int32(7)}, + }, + } + + for i, tt := range compatibleValuesTests { + err := ct.Set(tt.src) + assert.NoErrorf(t, err, "%d", i) + assert.EqualValues(t, tt.expected, ct.Get()) + } +} + +func TestCompositeTypeAssignTo(t *testing.T) { + ci := pgtype.NewConnInfo() + ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.TextOID}, + {"b", pgtype.Int4OID}, + }, ci) + require.NoError(t, err) + + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var a string + var b int32 + + err = ct.AssignTo([]interface{}{&a, &b}) + assert.NoError(t, err) + + assert.Equal(t, "foo", a) + assert.Equal(t, int32(42), b) + } + + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var a pgtype.Text + var b pgtype.Int4 + + err = ct.AssignTo([]interface{}{&a, &b}) + assert.NoError(t, err) + + assert.Equal(t, pgtype.Text{String: "foo", Valid: true}, a) + assert.Equal(t, pgtype.Int4{Int: 42, Valid: true}, b) + } + + // Allow nil destination component as no-op + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var b int32 + + err = ct.AssignTo([]interface{}{nil, &b}) + assert.NoError(t, err) + + assert.Equal(t, int32(42), b) + } + + // *[]interface{} dest when null + { + err := ct.Set(nil) + assert.NoError(t, err) + + var a pgtype.Text + var b pgtype.Int4 + dst := []interface{}{&a, &b} + + err = ct.AssignTo(&dst) + assert.NoError(t, err) + + assert.Nil(t, dst) + } + + // *[]interface{} dest when not null + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + var a pgtype.Text + var b pgtype.Int4 + dst := []interface{}{&a, &b} + + err = ct.AssignTo(&dst) + assert.NoError(t, err) + + assert.NotNil(t, dst) + assert.Equal(t, pgtype.Text{String: "foo", Valid: true}, a) + assert.Equal(t, pgtype.Int4{Int: 42, Valid: true}, b) + } + + // Struct fields positionally via reflection + { + err := ct.Set([]interface{}{"foo", int32(42)}) + assert.NoError(t, err) + + s := struct { + A string + B int32 + }{} + + err = ct.AssignTo(&s) + if assert.NoError(t, err) { + assert.Equal(t, "foo", s.A) + assert.Equal(t, int32(42), s.B) + } + } +} + +func TestCompositeTypeTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists ct_test; + +create type ct_test as ( + a text, + b int4 +);`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type ct_test") + + var oid uint32 + err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid) + require.NoError(t, err) + + defer conn.Exec(context.Background(), "drop type ct_test") + + ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{ + {"a", pgtype.TextOID}, + {"b", pgtype.Int4OID}, + }, conn.ConnInfo()) + require.NoError(t, err) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) + + // Use simple protocol to force text or binary encoding + simpleProtocols := []bool{true, false} + + var a string + var b int32 + + for _, simpleProtocol := range simpleProtocols { + err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QuerySimpleProtocol(simpleProtocol), + pgtype.CompositeFields{"hi", int32(42)}, + ).Scan( + []interface{}{&a, &b}, + ) + if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) { + assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol) + assert.EqualValuesf(t, 42, b, "Simple Protocol: %v", simpleProtocol) + } + } +} + +// https://github.com/jackc/pgx/issues/874 +func TestCompositeTypeTextDecodeNested(t *testing.T) { + newCompositeType := func(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType { + fields := make([]pgtype.CompositeTypeField, len(fieldNames)) + for i, name := range fieldNames { + fields[i] = pgtype.CompositeTypeField{Name: name} + } + + rowType, err := pgtype.NewCompositeTypeValues(name, fields, vals) + require.NoError(t, err) + return rowType + } + + dimensionsType := func() pgtype.ValueTranscoder { + return newCompositeType( + "dimensions", + []string{"width", "height"}, + &pgtype.Int4{}, + &pgtype.Int4{}, + ) + } + productImageType := func() pgtype.ValueTranscoder { + return newCompositeType( + "product_image_type", + []string{"source", "dimensions"}, + &pgtype.Text{}, + dimensionsType(), + ) + } + productImageSetType := newCompositeType( + "product_image_set_type", + []string{"name", "orig_image", "images"}, + &pgtype.Text{}, + productImageType(), + pgtype.NewArrayType("product_image", 0, func() pgtype.ValueTranscoder { + return productImageType() + }), + ) + + err := productImageSetType.DecodeText(nil, []byte(`(name,"(img1,""(11,11)"")","{""(img2,\\""(22,22)\\"")"",""(img3,\\""(33,33)\\"")""}")`)) + require.NoError(t, err) +} + +func Example_composite() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Println(err) + return + } + + defer conn.Close(context.Background()) + _, err = conn.Exec(context.Background(), `drop type if exists mytype;`) + if err != nil { + fmt.Println(err) + return + } + + _, err = conn.Exec(context.Background(), `create type mytype as ( + a int4, + b text +);`) + if err != nil { + fmt.Println(err) + return + } + defer conn.Exec(context.Background(), "drop type mytype") + + var oid uint32 + err = conn.QueryRow(context.Background(), `select 'mytype'::regtype::oid`).Scan(&oid) + if err != nil { + fmt.Println(err) + return + } + + ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{ + {"a", pgtype.Int4OID}, + {"b", pgtype.TextOID}, + }, conn.ConnInfo()) + if err != nil { + fmt.Println(err) + return + } + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) + + var a int + var b *string + + err = conn.QueryRow(context.Background(), "select $1::mytype", []interface{}{2, "bar"}).Scan([]interface{}{&a, &b}) + if err != nil { + fmt.Println(err) + return + } + + fmt.Printf("First: a=%d b=%s\n", a, *b) + + err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan([]interface{}{&a, &b}) + if err != nil { + fmt.Println(err) + return + } + + fmt.Printf("Second: a=%d b=%v\n", a, b) + + scanTarget := []interface{}{&a, &b} + err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(&scanTarget) + E(err) + + fmt.Printf("Third: isNull=%v\n", scanTarget == nil) + + // Output: + // First: a=2 b=bar + // Second: a=1 b= + // Third: isNull=true +} diff --git a/pgtype/convert.go b/pgtype/convert.go new file mode 100644 index 00000000..21e208f5 --- /dev/null +++ b/pgtype/convert.go @@ -0,0 +1,472 @@ +package pgtype + +import ( + "database/sql" + "fmt" + "math" + "reflect" + "time" +) + +const ( + maxUint = ^uint(0) + maxInt = int(maxUint >> 1) + minInt = -maxInt - 1 +) + +// underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8 +func underlyingNumberType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Int: + convVal := int(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int8: + convVal := int8(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int16: + convVal := int16(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int32: + convVal := int32(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int64: + convVal := int64(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint: + convVal := uint(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint8: + convVal := uint8(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint16: + convVal := uint16(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint32: + convVal := uint32(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint64: + convVal := uint64(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Float32: + convVal := float32(refVal.Float()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Float64: + convVal := refVal.Float() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingBoolType gets the underlying type that can be converted to Bool +func underlyingBoolType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Bool: + convVal := refVal.Bool() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingBytesType gets the underlying type that can be converted to []byte +func underlyingBytesType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + if refVal.Type().Elem().Kind() == reflect.Uint8 { + convVal := refVal.Bytes() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + } + + return nil, false +} + +// underlyingStringType gets the underlying type that can be converted to String +func underlyingStringType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingPtrType dereferences a pointer +func underlyingPtrType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + return nil, false +} + +// underlyingTimeType gets the underlying type that can be converted to time.Time +func underlyingTimeType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + timeType := reflect.TypeOf(time.Time{}) + if refVal.Type().ConvertibleTo(timeType) { + return refVal.Convert(timeType).Interface(), true + } + + return nil, false +} + +// underlyingUUIDType gets the underlying type that can be converted to [16]byte +func underlyingUUIDType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return time.Time{}, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + uuidType := reflect.TypeOf([16]byte{}) + if refVal.Type().ConvertibleTo(uuidType) { + return refVal.Convert(uuidType).Interface(), true + } + + return nil, false +} + +// underlyingSliceType gets the underlying slice type +func underlyingSliceType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + baseSliceType := reflect.SliceOf(refVal.Type().Elem()) + if refVal.Type().ConvertibleTo(baseSliceType) { + convVal := refVal.Convert(baseSliceType) + return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() + } + } + + return nil, false +} + +func int64AssignTo(srcVal int64, srcValid bool, dst interface{}) error { + if srcValid { + switch v := dst.(type) { + case *int: + if srcVal < int64(minInt) { + return fmt.Errorf("%d is less than minimum value for int", srcVal) + } else if srcVal > int64(maxInt) { + return fmt.Errorf("%d is greater than maximum value for int", srcVal) + } + *v = int(srcVal) + case *int8: + if srcVal < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", srcVal) + } else if srcVal > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", srcVal) + } + *v = int8(srcVal) + case *int16: + if srcVal < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", srcVal) + } else if srcVal > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", srcVal) + } + *v = int16(srcVal) + case *int32: + if srcVal < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", srcVal) + } else if srcVal > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", srcVal) + } + *v = int32(srcVal) + case *int64: + if srcVal < math.MinInt64 { + return fmt.Errorf("%d is less than minimum value for int64", srcVal) + } else if srcVal > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for int64", srcVal) + } + *v = int64(srcVal) + case *uint: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint", srcVal) + } else if uint64(srcVal) > uint64(maxUint) { + return fmt.Errorf("%d is greater than maximum value for uint", srcVal) + } + *v = uint(srcVal) + case *uint8: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint8", srcVal) + } else if srcVal > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", srcVal) + } + *v = uint8(srcVal) + case *uint16: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", srcVal) + } + *v = uint16(srcVal) + case *uint32: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", srcVal) + } + *v = uint32(srcVal) + case *uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint64", srcVal) + } + *v = uint64(srcVal) + case sql.Scanner: + return v.Scan(srcVal) + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return int64AssignTo(srcVal, srcValid, el.Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if el.OverflowInt(int64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetInt(int64(srcVal)) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for %T", srcVal, dst) + } + if el.OverflowUint(uint64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetUint(uint64(srcVal)) + return nil + } + } + return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + } + return nil + } + + // if dst is a pointer to pointer and srcStatus is not Valid, nil it out + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + if el.Kind() == reflect.Ptr { + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcValid, dst) +} + +func float64AssignTo(srcVal float64, srcValid bool, dst interface{}) error { + if srcValid { + switch v := dst.(type) { + case *float32: + *v = float32(srcVal) + case *float64: + *v = srcVal + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return float64AssignTo(srcVal, srcValid, el.Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + i64 := int64(srcVal) + if float64(i64) == srcVal { + return int64AssignTo(i64, srcValid, dst) + } + } + } + return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + } + return nil + } + + // if dst is a pointer to pointer and srcStatus is not Valid, nil it out + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + if el.Kind() == reflect.Ptr { + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcValid, dst) +} + +func NullAssignTo(dst interface{}) error { + dstPtr := reflect.ValueOf(dst) + + // AssignTo dst must always be a pointer + if dstPtr.Kind() != reflect.Ptr { + return &nullAssignmentError{dst: dst} + } + + dstVal := dstPtr.Elem() + + switch dstVal.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map: + dstVal.Set(reflect.Zero(dstVal.Type())) + return nil + } + + return &nullAssignmentError{dst: dst} +} + +var kindTypes map[reflect.Kind]reflect.Type + +func toInterface(dst reflect.Value, t reflect.Type) (interface{}, bool) { + nextDst := dst.Convert(t) + return nextDst.Interface(), dst.Type() != nextDst.Type() +} + +// GetAssignToDstType attempts to convert dst to something AssignTo can assign +// to. If dst is a pointer to pointer it allocates a value and returns the +// dereferences pointer. If dst is a named type such as *Foo where Foo is type +// Foo int16, it converts dst to *int16. +// +// GetAssignToDstType returns the converted dst and a bool representing if any +// change was made. +func GetAssignToDstType(dst interface{}) (interface{}, bool) { + dstPtr := reflect.ValueOf(dst) + + // AssignTo dst must always be a pointer + if dstPtr.Kind() != reflect.Ptr { + return nil, false + } + + dstVal := dstPtr.Elem() + + // if dst is a pointer to pointer, allocate space try again with the dereferenced pointer + if dstVal.Kind() == reflect.Ptr { + dstVal.Set(reflect.New(dstVal.Type().Elem())) + return dstVal.Interface(), true + } + + // if dst is pointer to a base type that has been renamed + if baseValType, ok := kindTypes[dstVal.Kind()]; ok { + return toInterface(dstPtr, reflect.PtrTo(baseValType)) + } + + if dstVal.Kind() == reflect.Slice { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + return toInterface(dstPtr, reflect.PtrTo(reflect.SliceOf(baseElemType))) + } + } + + if dstVal.Kind() == reflect.Array { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType))) + } + } + + if dstVal.Kind() == reflect.Struct { + if dstVal.Type().NumField() == 1 && dstVal.Type().Field(0).Anonymous { + dstPtr = dstVal.Field(0).Addr() + nested := dstVal.Type().Field(0).Type + if nested.Kind() == reflect.Array { + if baseElemType, ok := kindTypes[nested.Elem().Kind()]; ok { + return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(nested.Len(), baseElemType))) + } + } + if _, ok := kindTypes[nested.Kind()]; ok && dstPtr.CanInterface() { + return dstPtr.Interface(), true + } + } + } + + return nil, false +} + +func init() { + kindTypes = map[reflect.Kind]reflect.Type{ + reflect.Bool: reflect.TypeOf(false), + reflect.Float32: reflect.TypeOf(float32(0)), + reflect.Float64: reflect.TypeOf(float64(0)), + reflect.Int: reflect.TypeOf(int(0)), + reflect.Int8: reflect.TypeOf(int8(0)), + reflect.Int16: reflect.TypeOf(int16(0)), + reflect.Int32: reflect.TypeOf(int32(0)), + reflect.Int64: reflect.TypeOf(int64(0)), + reflect.Uint: reflect.TypeOf(uint(0)), + reflect.Uint8: reflect.TypeOf(uint8(0)), + reflect.Uint16: reflect.TypeOf(uint16(0)), + reflect.Uint32: reflect.TypeOf(uint32(0)), + reflect.Uint64: reflect.TypeOf(uint64(0)), + reflect.String: reflect.TypeOf(""), + } +} diff --git a/pgtype/custom_composite_test.go b/pgtype/custom_composite_test.go new file mode 100644 index 00000000..86203828 --- /dev/null +++ b/pgtype/custom_composite_test.go @@ -0,0 +1,87 @@ +package pgtype_test + +import ( + "context" + "errors" + "fmt" + "os" + + "github.com/jackc/pgtype" + pgx "github.com/jackc/pgx/v4" +) + +type MyType struct { + a int32 // NULL will cause decoding error + b *string // there can be NULL in this position in SQL +} + +func (dst *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + return errors.New("NULL values can't be decoded. Scan into a &*MyType to handle NULLs") + } + + if err := (pgtype.CompositeFields{&dst.a, &dst.b}).DecodeBinary(ci, src); err != nil { + return err + } + + return nil +} + +func (src MyType) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) (newBuf []byte, err error) { + a := pgtype.Int4{src.a, true} + var b pgtype.Text + if src.b != nil { + b = pgtype.Text{*src.b, true} + } else { + b = pgtype.Text{} + } + + return (pgtype.CompositeFields{&a, &b}).EncodeBinary(ci, buf) +} + +func ptrS(s string) *string { + return &s +} + +func E(err error) { + if err != nil { + panic(err) + } +} + +// ExampleCustomCompositeTypes demonstrates how support for custom types mappable to SQL +// composites can be added. +func Example_customCompositeTypes() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + E(err) + + defer conn.Close(context.Background()) + _, err = conn.Exec(context.Background(), `drop type if exists mytype; + +create type mytype as ( + a int4, + b text +);`) + E(err) + defer conn.Exec(context.Background(), "drop type mytype") + + var result *MyType + + // Demonstrates both passing and reading back composite values + err = conn.QueryRow(context.Background(), "select $1::mytype", + pgx.QueryResultFormats{pgx.BinaryFormatCode}, MyType{1, ptrS("foo")}). + Scan(&result) + E(err) + + fmt.Printf("First row: a=%d b=%s\n", result.a, *result.b) + + // Because we scan into &*MyType, NULLs are handled generically by assigning nil to result + err = conn.QueryRow(context.Background(), "select NULL::mytype", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result) + E(err) + + fmt.Printf("Second row: %v\n", result) + + // Output: + // First row: a=1 b=foo + // Second row: +} diff --git a/pgtype/database_sql.go b/pgtype/database_sql.go new file mode 100644 index 00000000..9d1cf822 --- /dev/null +++ b/pgtype/database_sql.go @@ -0,0 +1,41 @@ +package pgtype + +import ( + "database/sql/driver" + "errors" +) + +func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { + if valuer, ok := src.(driver.Valuer); ok { + return valuer.Value() + } + + if textEncoder, ok := src.(TextEncoder); ok { + buf, err := textEncoder.EncodeText(ci, nil) + if err != nil { + return nil, err + } + return string(buf), nil + } + + if binaryEncoder, ok := src.(BinaryEncoder); ok { + buf, err := binaryEncoder.EncodeBinary(ci, nil) + if err != nil { + return nil, err + } + return buf, nil + } + + return nil, errors.New("cannot convert to database/sql compatible value") +} + +func EncodeValueText(src TextEncoder) (interface{}, error) { + buf, err := src.EncodeText(nil, make([]byte, 0, 32)) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), err +} diff --git a/pgtype/date.go b/pgtype/date.go new file mode 100644 index 00000000..5b7f47e6 --- /dev/null +++ b/pgtype/date.go @@ -0,0 +1,264 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "time" + + "github.com/jackc/pgio" +) + +type Date struct { + Time time.Time + Valid bool + InfinityModifier InfinityModifier +} + +const ( + negativeInfinityDayOffset = -2147483648 + infinityDayOffset = 2147483647 +) + +func (dst *Date) Set(src interface{}) error { + if src == nil { + *dst = Date{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case time.Time: + *dst = Date{Time: value, Valid: true} + case string: + return dst.DecodeText(nil, []byte(value)) + case *time.Time: + if value == nil { + *dst = Date{} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Date{} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Date", value) + } + + return nil +} + +func (dst Date) Get() interface{} { + if !dst.Valid { + return nil + } + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst.Time +} + +func (src *Date) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } +} + +func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Date{} + return nil + } + + sbuf := string(src) + switch sbuf { + case "infinity": + *dst = Date{Valid: true, InfinityModifier: Infinity} + case "-infinity": + *dst = Date{Valid: true, InfinityModifier: -Infinity} + default: + t, err := time.ParseInLocation("2006-01-02", sbuf, time.UTC) + if err != nil { + return err + } + + *dst = Date{Time: t, Valid: true} + } + + return nil +} + +func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Date{} + return nil + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for date: %v", len(src)) + } + + dayOffset := int32(binary.BigEndian.Uint32(src)) + + switch dayOffset { + case infinityDayOffset: + *dst = Date{Valid: true, InfinityModifier: Infinity} + case negativeInfinityDayOffset: + *dst = Date{Valid: true, InfinityModifier: -Infinity} + default: + t := time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.UTC) + *dst = Date{Time: t, Valid: true} + } + + return nil +} + +func (src Date) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Format("2006-01-02") + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return append(buf, s...), nil +} + +func (src Date) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + var daysSinceDateEpoch int32 + switch src.InfinityModifier { + case None: + tUnix := time.Date(src.Time.Year(), src.Time.Month(), src.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() + dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() + + secSinceDateEpoch := tUnix - dateEpoch + daysSinceDateEpoch = int32(secSinceDateEpoch / 86400) + case Infinity: + daysSinceDateEpoch = infinityDayOffset + case NegativeInfinity: + daysSinceDateEpoch = negativeInfinityDayOffset + } + + return pgio.AppendInt32(buf, daysSinceDateEpoch), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Date) Scan(src interface{}) error { + if src == nil { + *dst = Date{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + case time.Time: + *dst = Date{Time: src, Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Date) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil +} + +func (src Date) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Format("2006-01-02") + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +func (dst *Date) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *dst = Date{} + return nil + } + + switch *s { + case "infinity": + *dst = Date{Valid: true, InfinityModifier: Infinity} + case "-infinity": + *dst = Date{Valid: true, InfinityModifier: -Infinity} + default: + t, err := time.ParseInLocation("2006-01-02", *s, time.UTC) + if err != nil { + return err + } + + *dst = Date{Time: t, Valid: true} + } + + return nil +} diff --git a/pgtype/date_array.go b/pgtype/date_array.go new file mode 100644 index 00000000..9d3b32e2 --- /dev/null +++ b/pgtype/date_array.go @@ -0,0 +1,505 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + "time" + + "github.com/jackc/pgio" +) + +type DateArray struct { + Elements []Date + Dimensions []ArrayDimension + Valid bool +} + +func (dst *DateArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = DateArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []time.Time: + if value == nil { + *dst = DateArray{} + } else if len(value) == 0 { + *dst = DateArray{Valid: true} + } else { + elements := make([]Date, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = DateArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*time.Time: + if value == nil { + *dst = DateArray{} + } else if len(value) == 0 { + *dst = DateArray{Valid: true} + } else { + elements := make([]Date, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = DateArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Date: + if value == nil { + *dst = DateArray{} + } else if len(value) == 0 { + *dst = DateArray{Valid: true} + } else { + *dst = DateArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = DateArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for DateArray", src) + } + if elementsLength == 0 { + *dst = DateArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to DateArray", src) + } + + *dst = DateArray{ + Elements: make([]Date, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Date, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to DateArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *DateArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to DateArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in DateArray", err) + } + index++ + + return index, nil +} + +func (dst DateArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *DateArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]time.Time: + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*time.Time: + *v = make([]*time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *DateArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from DateArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from DateArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = DateArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Date + + if len(uta.Elements) > 0 { + elements = make([]Date, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Date + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = DateArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = DateArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = DateArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Date, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = DateArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("date"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "date") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *DateArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src DateArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/date_array_test.go b/pgtype/date_array_test.go new file mode 100644 index 00000000..421427cd --- /dev/null +++ b/pgtype/date_array_test.go @@ -0,0 +1,327 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestDateArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "date[]", []interface{}{ + &pgtype.DateArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.DateArray{}, + &pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestDateArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.DateArray + }{ + { + source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]time.Time)(nil)), + result: pgtype.DateArray{}, + }, + { + source: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.DateArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestDateArrayAssignTo(t *testing.T) { + var timeSlice []time.Time + var timeSliceDim2 [][]time.Time + var timeSliceDim4 [][][][]time.Time + var timeArrayDim2 [2][1]time.Time + var timeArrayDim4 [2][1][1][3]time.Time + + simpleTests := []struct { + src pgtype.DateArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &timeSlice, + expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + }, + { + src: pgtype.DateArray{}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, + { + src: pgtype.DateArray{Valid: true}, + dst: &timeSlice, + expected: []time.Time{}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &timeSliceDim2, + expected: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &timeSliceDim4, + expected: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &timeArrayDim2, + expected: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &timeArrayDim4, + expected: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.DateArray + dst interface{} + }{ + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &timeSlice, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &timeArrayDim2, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &timeSlice, + }, + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &timeArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/date_test.go b/pgtype/date_test.go new file mode 100644 index 00000000..87425540 --- /dev/null +++ b/pgtype/date_test.go @@ -0,0 +1,168 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestDateTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "date", []interface{}{ + &pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Date{}, + &pgtype.Date{Valid: true, InfinityModifier: pgtype.Infinity}, + &pgtype.Date{Valid: true, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Date) + bt := b.(pgtype.Date) + + return at.Time.Equal(bt.Time) && at.Valid == bt.Valid && at.InfinityModifier == bt.InfinityModifier + }) +} + +func TestDateSet(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Date + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: "1999-12-31", result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}}, + } + + for i, tt := range successfulTests { + var d pgtype.Date + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestDateAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Date + dst interface{} + expected interface{} + }{ + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {src: pgtype.Date{Time: time.Time{}}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Date + dst interface{} + expected interface{} + }{ + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Date + dst interface{} + }{ + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Valid: true}, dst: &tim}, + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Valid: true}, dst: &tim}, + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} + +func TestDateMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Date + result string + }{ + {source: pgtype.Date{}, result: "null"}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Valid: true}, result: "\"2012-03-29\""}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29\""}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29\""}, + {source: pgtype.Date{InfinityModifier: pgtype.Infinity, Valid: true}, result: "\"infinity\""}, + {source: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "\"-infinity\""}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestDateUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Date + }{ + {source: "null", result: pgtype.Date{}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"infinity\"", result: pgtype.Date{InfinityModifier: pgtype.Infinity, Valid: true}}, + {source: "\"-infinity\"", result: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Date + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r.Time.Year() != tt.result.Time.Year() || r.Time.Month() != tt.result.Time.Month() || r.Time.Day() != tt.result.Time.Day() || r.Valid != tt.result.Valid || r.InfinityModifier != tt.result.InfinityModifier { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/daterange.go b/pgtype/daterange.go new file mode 100644 index 00000000..8b0c03f1 --- /dev/null +++ b/pgtype/daterange.go @@ -0,0 +1,257 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgio" +) + +type Daterange struct { + Lower Date + Upper Date + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (dst *Daterange) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Daterange{} + return nil + } + + switch value := src.(type) { + case Daterange: + *dst = value + case *Daterange: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return fmt.Errorf("cannot convert %v to Daterange", src) + } + + return nil +} + +func (src Daterange) Get() interface{} { + if !src.Valid { + return nil + } + return src +} + +func (src *Daterange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Daterange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Daterange{} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Daterange{Valid: true} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Daterange{} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Daterange{Valid: true} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Daterange) Scan(src interface{}) error { + if src == nil { + *dst = Daterange{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Daterange) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/daterange_test.go b/pgtype/daterange_test.go new file mode 100644 index 00000000..830942d0 --- /dev/null +++ b/pgtype/daterange_test.go @@ -0,0 +1,133 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestDaterangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "daterange", []interface{}{ + &pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + &pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + &pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + &pgtype.Daterange{}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Daterange) + b := bb.(pgtype.Daterange) + + return a.Valid == b.Valid && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Valid == b.Lower.Valid && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Valid == b.Upper.Valid && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} + +func TestDaterangeNormalize(t *testing.T) { + testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ + { + SQL: "select daterange('2010-01-01', '2010-01-11', '(]')", + Value: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(2010, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2010, 1, 12, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Daterange) + b := bb.(pgtype.Daterange) + + return a.Valid == b.Valid && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Valid == b.Lower.Valid && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Valid == b.Upper.Valid && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} + +func TestDaterangeSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Daterange + }{ + { + source: nil, + result: pgtype.Daterange{}, + }, + { + source: &pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + result: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + { + source: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + result: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + { + source: "[1990-12-31,2028-01-01)", + result: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Daterange + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/enum_array.go b/pgtype/enum_array.go new file mode 100644 index 00000000..dbfb211d --- /dev/null +++ b/pgtype/enum_array.go @@ -0,0 +1,418 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "fmt" + "reflect" +) + +type EnumArray struct { + Elements []GenericText + Dimensions []ArrayDimension + Valid bool +} + +func (dst *EnumArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = EnumArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []string: + if value == nil { + *dst = EnumArray{} + } else if len(value) == 0 { + *dst = EnumArray{Valid: true} + } else { + elements := make([]GenericText, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = EnumArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*string: + if value == nil { + *dst = EnumArray{} + } else if len(value) == 0 { + *dst = EnumArray{Valid: true} + } else { + elements := make([]GenericText, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = EnumArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []GenericText: + if value == nil { + *dst = EnumArray{} + } else if len(value) == 0 { + *dst = EnumArray{Valid: true} + } else { + *dst = EnumArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = EnumArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for EnumArray", src) + } + if elementsLength == 0 { + *dst = EnumArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to EnumArray", src) + } + + *dst = EnumArray{ + Elements: make([]GenericText, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]GenericText, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to EnumArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *EnumArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to EnumArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in EnumArray", err) + } + index++ + + return index, nil +} + +func (dst EnumArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *EnumArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *EnumArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from EnumArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from EnumArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = EnumArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []GenericText + + if len(uta.Elements) > 0 { + elements = make([]GenericText, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem GenericText + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = EnumArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (src EnumArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *EnumArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src EnumArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/enum_array_test.go b/pgtype/enum_array_test.go new file mode 100644 index 00000000..7d0ff864 --- /dev/null +++ b/pgtype/enum_array_test.go @@ -0,0 +1,281 @@ +package pgtype_test + +import ( + "context" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestEnumArrayTranscode(t *testing.T) { + setupConn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, setupConn) + + if _, err := setupConn.Exec(context.Background(), "drop type if exists color"); err != nil { + t.Fatal(err) + } + if _, err := setupConn.Exec(context.Background(), "create type color as enum ('red', 'green', 'blue')"); err != nil { + t.Fatal(err) + } + + testutil.TestSuccessfulTranscode(t, "color[]", []interface{}{ + &pgtype.EnumArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "red", Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.EnumArray{}, + &pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "red", Valid: true}, + {String: "green", Valid: true}, + {String: "blue", Valid: true}, + {String: "red", Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestEnumArrayArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.EnumArray + }{ + { + source: []string{"foo"}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]string)(nil)), + result: pgtype.EnumArray{}, + }, + { + source: [][]string{{"foo"}, {"bar"}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1]string{{"foo"}, {"bar"}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.EnumArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestEnumArrayArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string + + simpleTests := []struct { + src pgtype.EnumArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.EnumArray{}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + { + src: pgtype.EnumArray{Valid: true}, + dst: &stringSlice, + expected: []string{}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &stringSliceDim2, + expected: [][]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &stringSliceDim4, + expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &stringArrayDim2, + expected: [2][1]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.EnumArray + dst interface{} + }{ + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &stringSlice, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &stringArrayDim2, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &stringSlice, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &stringArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/enum_type.go b/pgtype/enum_type.go new file mode 100644 index 00000000..73ee3823 --- /dev/null +++ b/pgtype/enum_type.go @@ -0,0 +1,158 @@ +package pgtype + +import "fmt" + +// EnumType represents a enum type. While it implements Value, this is only in service of its type conversion duties +// when registered as a data type in a ConnType. It should not be used directly as a Value. +type EnumType struct { + value string + valid bool + + typeName string // PostgreSQL type name + members []string // enum members + membersMap map[string]string // map to quickly lookup member and reuse string instead of allocating +} + +// NewEnumType initializes a new EnumType. It retains a read-only reference to members. members must not be changed. +func NewEnumType(typeName string, members []string) *EnumType { + et := &EnumType{typeName: typeName, members: members} + et.membersMap = make(map[string]string, len(members)) + for _, m := range members { + et.membersMap[m] = m + } + return et +} + +func (et *EnumType) NewTypeValue() Value { + return &EnumType{ + value: et.value, + valid: et.valid, + + typeName: et.typeName, + members: et.members, + membersMap: et.membersMap, + } +} + +func (et *EnumType) TypeName() string { + return et.typeName +} + +func (et *EnumType) Members() []string { + return et.members +} + +// Set assigns src to dst. Set purposely does not check that src is a member. This allows continued error free +// operation in the event the PostgreSQL enum type is modified during a connection. +func (dst *EnumType) Set(src interface{}) error { + if src == nil { + dst.valid = false + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case string: + dst.value = value + dst.valid = true + case *string: + if value == nil { + dst.valid = false + } else { + dst.value = *value + dst.valid = true + } + case []byte: + if value == nil { + dst.valid = false + } else { + dst.value = string(value) + dst.valid = true + } + default: + if originalSrc, ok := underlyingStringType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to enum %s", value, dst.typeName) + } + + return nil +} + +func (dst EnumType) Get() interface{} { + if !dst.valid { + return nil + } + return dst.value +} + +func (src *EnumType) AssignTo(dst interface{}) error { + if !src.valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *string: + *v = src.value + return nil + case *[]byte: + *v = make([]byte, len(src.value)) + copy(*v, src.value) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } +} + +func (EnumType) PreferredResultFormat() int16 { + return TextFormatCode +} + +func (dst *EnumType) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + dst.valid = false + return nil + } + + // Lookup the string in membersMap to avoid an allocation. + if s, found := dst.membersMap[string(src)]; found { + dst.value = s + } else { + // If an enum type is modified after the initial connection it is possible to receive an unexpected value. + // Gracefully handle this situation. Purposely NOT modifying members and membersMap to allow for sharing members + // and membersMap between connections. + dst.value = string(src) + } + dst.valid = true + + return nil +} + +func (dst *EnumType) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) +} + +func (EnumType) PreferredParamFormat() int16 { + return TextFormatCode +} + +func (src EnumType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.valid { + return nil, nil + } + + return append(buf, src.value...), nil +} + +func (src EnumType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return src.EncodeText(ci, buf) +} diff --git a/pgtype/enum_type_test.go b/pgtype/enum_type_test.go new file mode 100644 index 00000000..4dd88f2a --- /dev/null +++ b/pgtype/enum_type_test.go @@ -0,0 +1,148 @@ +package pgtype_test + +import ( + "bytes" + "context" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupEnum(t *testing.T, conn *pgx.Conn) *pgtype.EnumType { + _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") + require.NoError(t, err) + + _, err = conn.Exec(context.Background(), "create type pgtype_enum_color as enum ('blue', 'green', 'purple');") + require.NoError(t, err) + + var oid uint32 + err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "pgtype_enum_color").Scan(&oid) + require.NoError(t, err) + + et := pgtype.NewEnumType("pgtype_enum_color", []string{"blue", "green", "purple"}) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "pgtype_enum_color", OID: oid}) + + return et +} + +func cleanupEnum(t *testing.T, conn *pgx.Conn) { + _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") + require.NoError(t, err) +} + +func TestEnumTypeTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + setupEnum(t, conn) + defer cleanupEnum(t, conn) + + var dst string + err := conn.QueryRow(context.Background(), "select $1::pgtype_enum_color", "blue").Scan(&dst) + require.NoError(t, err) + require.EqualValues(t, "blue", dst) +} + +func TestEnumTypeSet(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + enumType := setupEnum(t, conn) + defer cleanupEnum(t, conn) + + successfulTests := []struct { + source interface{} + result interface{} + }{ + {source: "blue", result: "blue"}, + {source: _string("green"), result: "green"}, + {source: (*string)(nil), result: nil}, + } + + for i, tt := range successfulTests { + err := enumType.Set(tt.source) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.result, enumType.Get(), "%d", i) + } +} + +func TestEnumTypeAssignTo(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + enumType := setupEnum(t, conn) + defer cleanupEnum(t, conn) + + { + var s string + + err := enumType.Set("blue") + require.NoError(t, err) + + err = enumType.AssignTo(&s) + require.NoError(t, err) + + assert.EqualValues(t, "blue", s) + } + + { + var ps *string + + err := enumType.Set("blue") + require.NoError(t, err) + + err = enumType.AssignTo(&ps) + require.NoError(t, err) + + assert.EqualValues(t, "blue", *ps) + } + + { + var ps *string + + err := enumType.Set(nil) + require.NoError(t, err) + + err = enumType.AssignTo(&ps) + require.NoError(t, err) + + assert.EqualValues(t, (*string)(nil), ps) + } + + var buf []byte + bytesTests := []struct { + src interface{} + dst *[]byte + expected []byte + }{ + {src: "blue", dst: &buf, expected: []byte("blue")}, + {src: nil, dst: &buf, expected: nil}, + } + + for i, tt := range bytesTests { + err := enumType.Set(tt.src) + require.NoError(t, err, "%d", i) + + err = enumType.AssignTo(tt.dst) + require.NoError(t, err, "%d", i) + + if bytes.Compare(*tt.dst, tt.expected) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst) + } + } + + { + var s string + + err := enumType.Set(nil) + require.NoError(t, err) + + err = enumType.AssignTo(&s) + require.Error(t, err) + } + +} diff --git a/pgtype/float4.go b/pgtype/float4.go new file mode 100644 index 00000000..36c46346 --- /dev/null +++ b/pgtype/float4.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +type Float4 struct { + Float float32 + Valid bool +} + +func (dst *Float4) Set(src interface{}) error { + if src == nil { + *dst = Float4{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case float32: + *dst = Float4{Float: value, Valid: true} + case float64: + *dst = Float4{Float: float32(value), Valid: true} + case int8: + *dst = Float4{Float: float32(value), Valid: true} + case uint8: + *dst = Float4{Float: float32(value), Valid: true} + case int16: + *dst = Float4{Float: float32(value), Valid: true} + case uint16: + *dst = Float4{Float: float32(value), Valid: true} + case int32: + f32 := float32(value) + if int32(f32) == value { + *dst = Float4{Float: f32, Valid: true} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case uint32: + f32 := float32(value) + if uint32(f32) == value { + *dst = Float4{Float: f32, Valid: true} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case int64: + f32 := float32(value) + if int64(f32) == value { + *dst = Float4{Float: f32, Valid: true} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case uint64: + f32 := float32(value) + if uint64(f32) == value { + *dst = Float4{Float: f32, Valid: true} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case int: + f32 := float32(value) + if int(f32) == value { + *dst = Float4{Float: f32, Valid: true} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case uint: + f32 := float32(value) + if uint(f32) == value { + *dst = Float4{Float: f32, Valid: true} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case string: + num, err := strconv.ParseFloat(value, 32) + if err != nil { + return err + } + *dst = Float4{Float: float32(num), Valid: true} + case *float64: + if value == nil { + *dst = Float4{} + } else { + return dst.Set(*value) + } + case *float32: + if value == nil { + *dst = Float4{} + } else { + return dst.Set(*value) + } + case *int8: + if value == nil { + *dst = Float4{} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Float4{} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Float4{} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Float4{} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Float4{} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Float4{} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Float4{} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Float4{} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Float4{} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Float4{} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Float4{} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float8", value) + } + + return nil +} + +func (dst Float4) Get() interface{} { + if !dst.Valid { + return nil + } + return dst.Float +} + +func (src *Float4) AssignTo(dst interface{}) error { + return float64AssignTo(float64(src.Float), src.Valid, dst) +} + +func (dst *Float4) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float4{} + return nil + } + + n, err := strconv.ParseFloat(string(src), 32) + if err != nil { + return err + } + + *dst = Float4{Float: float32(n), Valid: true} + return nil +} + +func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float4{} + return nil + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for float4: %v", len(src)) + } + + n := int32(binary.BigEndian.Uint32(src)) + + *dst = Float4{Float: math.Float32frombits(uint32(n)), Valid: true} + return nil +} + +func (src Float4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 32)...) + return buf, nil +} + +func (src Float4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = pgio.AppendUint32(buf, math.Float32bits(src.Float)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Float4) Scan(src interface{}) error { + if src == nil { + *dst = Float4{} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Float4{Float: float32(src), Valid: true} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float4) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return float64(src.Float), nil +} diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go new file mode 100644 index 00000000..dcf6c1f7 --- /dev/null +++ b/pgtype/float4_array.go @@ -0,0 +1,504 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type Float4Array struct { + Elements []Float4 + Dimensions []ArrayDimension + Valid bool +} + +func (dst *Float4Array) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Float4Array{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []float32: + if value == nil { + *dst = Float4Array{} + } else if len(value) == 0 { + *dst = Float4Array{Valid: true} + } else { + elements := make([]Float4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Float4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*float32: + if value == nil { + *dst = Float4Array{} + } else if len(value) == 0 { + *dst = Float4Array{Valid: true} + } else { + elements := make([]Float4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Float4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Float4: + if value == nil { + *dst = Float4Array{} + } else if len(value) == 0 { + *dst = Float4Array{Valid: true} + } else { + *dst = Float4Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = Float4Array{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for Float4Array", src) + } + if elementsLength == 0 { + *dst = Float4Array{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float4Array", src) + } + + *dst = Float4Array{ + Elements: make([]Float4, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Float4, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to Float4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *Float4Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to Float4Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in Float4Array", err) + } + index++ + + return index, nil +} + +func (dst Float4Array) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *Float4Array) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]float32: + *v = make([]float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float32: + *v = make([]*float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *Float4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from Float4Array") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from Float4Array") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float4Array{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Float4 + + if len(uta.Elements) > 0 { + elements = make([]Float4, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Float4 + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Float4Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float4Array{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Float4Array{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Float4, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Float4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("float4"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "float4") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Float4Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float4Array) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/float4_array_test.go b/pgtype/float4_array_test.go new file mode 100644 index 00000000..9b401ac8 --- /dev/null +++ b/pgtype/float4_array_test.go @@ -0,0 +1,282 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestFloat4ArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "float4[]", []interface{}{ + &pgtype.Float4Array{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.Float4Array{}, + &pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {}, + {Float: 6, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestFloat4ArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float4Array + }{ + { + source: []float32{1}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]float32)(nil)), + result: pgtype.Float4Array{}, + }, + { + source: [][]float32{{1}, {2}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {Float: 5, Valid: true}, + {Float: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1]float32{{1}, {2}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {Float: 5, Valid: true}, + {Float: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Float4Array + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat4ArrayAssignTo(t *testing.T) { + var float32Slice []float32 + var namedFloat32Slice _float32Slice + var float32SliceDim2 [][]float32 + var float32SliceDim4 [][][][]float32 + var float32ArrayDim2 [2][1]float32 + var float32ArrayDim4 [2][1][1][3]float32 + + simpleTests := []struct { + src pgtype.Float4Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1.23, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &float32Slice, + expected: []float32{1.23}, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1.23, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &namedFloat32Slice, + expected: _float32Slice{1.23}, + }, + { + src: pgtype.Float4Array{}, + dst: &float32Slice, + expected: (([]float32)(nil)), + }, + { + src: pgtype.Float4Array{Valid: true}, + dst: &float32Slice, + expected: []float32{}, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + expected: [][]float32{{1}, {2}}, + dst: &float32SliceDim2, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {Float: 5, Valid: true}, + {Float: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float32SliceDim4, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + expected: [2][1]float32{{1}, {2}}, + dst: &float32ArrayDim2, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{ + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {Float: 5, Valid: true}, + {Float: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float32ArrayDim4, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float4Array + dst interface{} + }{ + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &float32Slice, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &float32ArrayDim2, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &float32Slice, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &float32ArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go new file mode 100644 index 00000000..191df65e --- /dev/null +++ b/pgtype/float4_test.go @@ -0,0 +1,149 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestFloat4Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "float4", []interface{}{ + &pgtype.Float4{Float: -1, Valid: true}, + &pgtype.Float4{Float: 0, Valid: true}, + &pgtype.Float4{Float: 0.00001, Valid: true}, + &pgtype.Float4{Float: 1, Valid: true}, + &pgtype.Float4{Float: 9999.99, Valid: true}, + &pgtype.Float4{Float: 0}, + }) +} + +func TestFloat4Set(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float4 + }{ + {source: float32(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: float64(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: int8(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: int16(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: int32(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: int64(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: int8(-1), result: pgtype.Float4{Float: -1, Valid: true}}, + {source: int16(-1), result: pgtype.Float4{Float: -1, Valid: true}}, + {source: int32(-1), result: pgtype.Float4{Float: -1, Valid: true}}, + {source: int64(-1), result: pgtype.Float4{Float: -1, Valid: true}}, + {source: uint8(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: uint16(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: uint32(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: uint64(1), result: pgtype.Float4{Float: 1, Valid: true}}, + {source: "1", result: pgtype.Float4{Float: 1, Valid: true}}, + {source: _int8(1), result: pgtype.Float4{Float: 1, Valid: true}}, + } + + for i, tt := range successfulTests { + var r pgtype.Float4 + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat4AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src pgtype.Float4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &f32, expected: float32(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &f64, expected: float64(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &i16, expected: int16(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &i32, expected: int32(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &i64, expected: int64(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &i, expected: int(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &ui, expected: uint(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Float4{Float: 0}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Float4{Float: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Float4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &pf32, expected: float32(42)}, + {src: pgtype.Float4{Float: 42, Valid: true}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float4 + dst interface{} + }{ + {src: pgtype.Float4{Float: 150, Valid: true}, dst: &i8}, + {src: pgtype.Float4{Float: 40000, Valid: true}, dst: &i16}, + {src: pgtype.Float4{Float: -1, Valid: true}, dst: &ui8}, + {src: pgtype.Float4{Float: -1, Valid: true}, dst: &ui16}, + {src: pgtype.Float4{Float: -1, Valid: true}, dst: &ui32}, + {src: pgtype.Float4{Float: -1, Valid: true}, dst: &ui64}, + {src: pgtype.Float4{Float: -1, Valid: true}, dst: &ui}, + {src: pgtype.Float4{Float: 0}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/float8.go b/pgtype/float8.go new file mode 100644 index 00000000..1038d283 --- /dev/null +++ b/pgtype/float8.go @@ -0,0 +1,258 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +type Float8 struct { + Float float64 + Valid bool +} + +func (dst *Float8) Set(src interface{}) error { + if src == nil { + *dst = Float8{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case float32: + *dst = Float8{Float: float64(value), Valid: true} + case float64: + *dst = Float8{Float: value, Valid: true} + case int8: + *dst = Float8{Float: float64(value), Valid: true} + case uint8: + *dst = Float8{Float: float64(value), Valid: true} + case int16: + *dst = Float8{Float: float64(value), Valid: true} + case uint16: + *dst = Float8{Float: float64(value), Valid: true} + case int32: + *dst = Float8{Float: float64(value), Valid: true} + case uint32: + *dst = Float8{Float: float64(value), Valid: true} + case int64: + f64 := float64(value) + if int64(f64) == value { + *dst = Float8{Float: f64, Valid: true} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case uint64: + f64 := float64(value) + if uint64(f64) == value { + *dst = Float8{Float: f64, Valid: true} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case int: + f64 := float64(value) + if int(f64) == value { + *dst = Float8{Float: f64, Valid: true} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case uint: + f64 := float64(value) + if uint(f64) == value { + *dst = Float8{Float: f64, Valid: true} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case string: + num, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + *dst = Float8{Float: float64(num), Valid: true} + case *float64: + if value == nil { + *dst = Float8{} + } else { + return dst.Set(*value) + } + case *float32: + if value == nil { + *dst = Float8{} + } else { + return dst.Set(*value) + } + case *int8: + if value == nil { + *dst = Float8{} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Float8{} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Float8{} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Float8{} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Float8{} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Float8{} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Float8{} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Float8{} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Float8{} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Float8{} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Float8{} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float8", value) + } + + return nil +} + +func (dst Float8) Get() interface{} { + if !dst.Valid { + return nil + } + return dst.Float +} + +func (src *Float8) AssignTo(dst interface{}) error { + return float64AssignTo(src.Float, src.Valid, dst) +} + +func (dst *Float8) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float8{} + return nil + } + + n, err := strconv.ParseFloat(string(src), 64) + if err != nil { + return err + } + + *dst = Float8{Float: n, Valid: true} + return nil +} + +func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float8{} + return nil + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for float4: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint64(src)) + + *dst = Float8{Float: math.Float64frombits(uint64(n)), Valid: true} + return nil +} + +func (src Float8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 64)...) + return buf, nil +} + +func (src Float8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.Float)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Float8) Scan(src interface{}) error { + if src == nil { + *dst = Float8{} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Float8{Float: src, Valid: true} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float8) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return src.Float, nil +} diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go new file mode 100644 index 00000000..5e85e236 --- /dev/null +++ b/pgtype/float8_array.go @@ -0,0 +1,504 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type Float8Array struct { + Elements []Float8 + Dimensions []ArrayDimension + Valid bool +} + +func (dst *Float8Array) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Float8Array{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []float64: + if value == nil { + *dst = Float8Array{} + } else if len(value) == 0 { + *dst = Float8Array{Valid: true} + } else { + elements := make([]Float8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Float8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*float64: + if value == nil { + *dst = Float8Array{} + } else if len(value) == 0 { + *dst = Float8Array{Valid: true} + } else { + elements := make([]Float8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Float8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Float8: + if value == nil { + *dst = Float8Array{} + } else if len(value) == 0 { + *dst = Float8Array{Valid: true} + } else { + *dst = Float8Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = Float8Array{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for Float8Array", src) + } + if elementsLength == 0 { + *dst = Float8Array{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float8Array", src) + } + + *dst = Float8Array{ + Elements: make([]Float8, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Float8, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to Float8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *Float8Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to Float8Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in Float8Array", err) + } + index++ + + return index, nil +} + +func (dst Float8Array) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *Float8Array) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]float64: + *v = make([]float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float64: + *v = make([]*float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *Float8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from Float8Array") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from Float8Array") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float8Array{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Float8 + + if len(uta.Elements) > 0 { + elements = make([]Float8, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Float8 + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Float8Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float8Array{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Float8Array{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Float8, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Float8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("float8"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "float8") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Float8Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float8Array) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/float8_array_test.go b/pgtype/float8_array_test.go new file mode 100644 index 00000000..52209238 --- /dev/null +++ b/pgtype/float8_array_test.go @@ -0,0 +1,258 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestFloat8ArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "float8[]", []interface{}{ + &pgtype.Float8Array{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.Float8Array{}, + &pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {}, + {Float: 6, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestFloat8ArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float8Array + }{ + { + source: []float64{1}, + result: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]float64)(nil)), + result: pgtype.Float8Array{}, + }, + { + source: [][]float64{{1}, {2}}, + result: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {Float: 5, Valid: true}, + {Float: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Float8Array + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat8ArrayAssignTo(t *testing.T) { + var float64Slice []float64 + var namedFloat64Slice _float64Slice + var float64SliceDim2 [][]float64 + var float64SliceDim4 [][][][]float64 + var float64ArrayDim2 [2][1]float64 + var float64ArrayDim4 [2][1][1][3]float64 + + simpleTests := []struct { + src pgtype.Float8Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1.23, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &float64Slice, + expected: []float64{1.23}, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1.23, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &namedFloat64Slice, + expected: _float64Slice{1.23}, + }, + { + src: pgtype.Float8Array{}, + dst: &float64Slice, + expected: (([]float64)(nil)), + }, + { + src: pgtype.Float8Array{Valid: true}, + dst: &float64Slice, + expected: []float64{}, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + expected: [][]float64{{1}, {2}}, + dst: &float64SliceDim2, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {Float: 5, Valid: true}, + {Float: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + expected: [][][][]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float64SliceDim4, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + expected: [2][1]float64{{1}, {2}}, + dst: &float64ArrayDim2, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{ + {Float: 1, Valid: true}, + {Float: 2, Valid: true}, + {Float: 3, Valid: true}, + {Float: 4, Valid: true}, + {Float: 5, Valid: true}, + {Float: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + expected: [2][1][1][3]float64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &float64ArrayDim4, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float8Array + dst interface{} + }{ + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &float64Slice, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &float64ArrayDim2, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &float64Slice, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Valid: true}, {Float: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &float64ArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go new file mode 100644 index 00000000..dcc45879 --- /dev/null +++ b/pgtype/float8_test.go @@ -0,0 +1,149 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestFloat8Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "float8", []interface{}{ + &pgtype.Float8{Float: -1, Valid: true}, + &pgtype.Float8{Float: 0, Valid: true}, + &pgtype.Float8{Float: 0.00001, Valid: true}, + &pgtype.Float8{Float: 1, Valid: true}, + &pgtype.Float8{Float: 9999.99, Valid: true}, + &pgtype.Float8{Float: 0}, + }) +} + +func TestFloat8Set(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float8 + }{ + {source: float32(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: float64(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: int8(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: int16(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: int32(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: int64(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: int8(-1), result: pgtype.Float8{Float: -1, Valid: true}}, + {source: int16(-1), result: pgtype.Float8{Float: -1, Valid: true}}, + {source: int32(-1), result: pgtype.Float8{Float: -1, Valid: true}}, + {source: int64(-1), result: pgtype.Float8{Float: -1, Valid: true}}, + {source: uint8(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: uint16(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: uint32(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: uint64(1), result: pgtype.Float8{Float: 1, Valid: true}}, + {source: "1", result: pgtype.Float8{Float: 1, Valid: true}}, + {source: _int8(1), result: pgtype.Float8{Float: 1, Valid: true}}, + } + + for i, tt := range successfulTests { + var r pgtype.Float8 + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat8AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src pgtype.Float8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &f32, expected: float32(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &f64, expected: float64(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &i16, expected: int16(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &i32, expected: int32(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &i64, expected: int64(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &i, expected: int(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &ui, expected: uint(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Float8{Float: 0}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Float8{Float: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Float8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &pf32, expected: float32(42)}, + {src: pgtype.Float8{Float: 42, Valid: true}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float8 + dst interface{} + }{ + {src: pgtype.Float8{Float: 150, Valid: true}, dst: &i8}, + {src: pgtype.Float8{Float: 40000, Valid: true}, dst: &i16}, + {src: pgtype.Float8{Float: -1, Valid: true}, dst: &ui8}, + {src: pgtype.Float8{Float: -1, Valid: true}, dst: &ui16}, + {src: pgtype.Float8{Float: -1, Valid: true}, dst: &ui32}, + {src: pgtype.Float8{Float: -1, Valid: true}, dst: &ui64}, + {src: pgtype.Float8{Float: -1, Valid: true}, dst: &ui}, + {src: pgtype.Float8{Float: 0}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/generic_binary.go b/pgtype/generic_binary.go new file mode 100644 index 00000000..76a1d351 --- /dev/null +++ b/pgtype/generic_binary.go @@ -0,0 +1,39 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// GenericBinary is a placeholder for binary format values that no other type exists +// to handle. +type GenericBinary Bytea + +func (dst *GenericBinary) Set(src interface{}) error { + return (*Bytea)(dst).Set(src) +} + +func (dst GenericBinary) Get() interface{} { + return (Bytea)(dst).Get() +} + +func (src *GenericBinary) AssignTo(dst interface{}) error { + return (*Bytea)(src).AssignTo(dst) +} + +func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Bytea)(dst).DecodeBinary(ci, src) +} + +func (src GenericBinary) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Bytea)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *GenericBinary) Scan(src interface{}) error { + return (*Bytea)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src GenericBinary) Value() (driver.Value, error) { + return (Bytea)(src).Value() +} diff --git a/pgtype/generic_text.go b/pgtype/generic_text.go new file mode 100644 index 00000000..dbf5b47e --- /dev/null +++ b/pgtype/generic_text.go @@ -0,0 +1,39 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// GenericText is a placeholder for text format values that no other type exists +// to handle. +type GenericText Text + +func (dst *GenericText) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst GenericText) Get() interface{} { + return (Text)(dst).Get() +} + +func (src *GenericText) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (src GenericText) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeText(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *GenericText) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src GenericText) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/pgtype/go.mod b/pgtype/go.mod new file mode 100644 index 00000000..b2f1cc10 --- /dev/null +++ b/pgtype/go.mod @@ -0,0 +1,10 @@ +module github.com/jackc/pgtype + +go 1.13 + +require ( + github.com/jackc/pgconn v1.10.1 + github.com/jackc/pgio v1.0.0 + github.com/jackc/pgx/v4 v4.14.2-0.20211129172902-cf0de913ee8f + github.com/stretchr/testify v1.7.0 +) diff --git a/pgtype/go.sum b/pgtype/go.sum new file mode 100644 index 00000000..2a835726 --- /dev/null +++ b/pgtype/go.sum @@ -0,0 +1,180 @@ +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= +github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= +github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= +github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgconn v1.10.1 h1:DzdIHIjG1AxGwoEEqS+mGsURyjt4enSmqzACXvVzOT8= +github.com/jackc/pgconn v1.10.1/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.2.0 h1:r7JypeP2D3onoQTCxWdTpCtJ4D+qpKr0TxvoyMhZ5ns= +github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= +github.com/jackc/pgtype v1.9.1/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= +github.com/jackc/pgx/v4 v4.14.2-0.20211129172902-cf0de913ee8f h1:Y3Es3mIYatTvP4CXPXfmJtHWe8eq4E8owY6Fq61hEik= +github.com/jackc/pgx/v4 v4.14.2-0.20211129172902-cf0de913ee8f/go.mod h1:RgDuE4Z34o7XE92RpLsvFiOEfrAUT0Xt2KxvX73W06M= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/pgtype/hstore.go b/pgtype/hstore.go new file mode 100644 index 00000000..25406a74 --- /dev/null +++ b/pgtype/hstore.go @@ -0,0 +1,446 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "errors" + "fmt" + "strings" + "unicode" + "unicode/utf8" + + "github.com/jackc/pgio" +) + +// Hstore represents an hstore column that can be null or have null values +// associated with its keys. +type Hstore struct { + Map map[string]Text + Valid bool +} + +func (dst *Hstore) Set(src interface{}) error { + if src == nil { + *dst = Hstore{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case map[string]string: + m := make(map[string]Text, len(value)) + for k, v := range value { + m[k] = Text{String: v, Valid: true} + } + *dst = Hstore{Map: m, Valid: true} + case map[string]*string: + m := make(map[string]Text, len(value)) + for k, v := range value { + if v == nil { + m[k] = Text{} + } else { + m[k] = Text{String: *v, Valid: true} + } + } + *dst = Hstore{Map: m, Valid: true} + default: + return fmt.Errorf("cannot convert %v to Hstore", src) + } + + return nil +} + +func (dst Hstore) Get() interface{} { + if !dst.Valid { + return nil + } + return dst.Map +} + +func (src *Hstore) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *map[string]string: + *v = make(map[string]string, len(src.Map)) + for k, val := range src.Map { + if !val.Valid { + return fmt.Errorf("cannot decode %#v into %T", src, dst) + } + (*v)[k] = val.String + } + return nil + case *map[string]*string: + *v = make(map[string]*string, len(src.Map)) + for k, val := range src.Map { + if val.Valid { + (*v)[k] = &val.String + } else { + (*v)[k] = nil + } + } + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } +} + +func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Hstore{} + return nil + } + + keys, values, err := parseHstore(string(src)) + if err != nil { + return err + } + + m := make(map[string]Text, len(keys)) + for i := range keys { + m[keys[i]] = values[i] + } + + *dst = Hstore{Map: m, Valid: true} + return nil +} + +func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Hstore{} + return nil + } + + rp := 0 + + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + m := make(map[string]Text, pairCount) + + for i := 0; i < pairCount; i++ { + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + if len(src[rp:]) < keyLen { + return fmt.Errorf("hstore incomplete %v", src) + } + key := string(src[rp : rp+keyLen]) + rp += keyLen + + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + var valueBuf []byte + if valueLen >= 0 { + valueBuf = src[rp : rp+valueLen] + rp += valueLen + } + + var value Text + err := value.DecodeBinary(ci, valueBuf) + if err != nil { + return err + } + m[key] = value + } + + *dst = Hstore{Map: m, Valid: true} + + return nil +} + +func (src Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + firstPair := true + + inElemBuf := make([]byte, 0, 32) + for k, v := range src.Map { + if firstPair { + firstPair = false + } else { + buf = append(buf, ',') + } + + buf = append(buf, quoteHstoreElementIfNeeded(k)...) + buf = append(buf, "=>"...) + + elemBuf, err := v.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + + if elemBuf == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, quoteHstoreElementIfNeeded(string(elemBuf))...) + } + } + + return buf, nil +} + +func (src Hstore) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = pgio.AppendInt32(buf, int32(len(src.Map))) + + var err error + for k, v := range src.Map { + buf = pgio.AppendInt32(buf, int32(len(k))) + buf = append(buf, k...) + + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := v.EncodeText(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, err +} + +var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteHstoreElement(src string) string { + return `"` + quoteArrayReplacer.Replace(src) + `"` +} + +func quoteHstoreElementIfNeeded(src string) string { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || strings.ContainsAny(src, ` {},"\=>`) { + return quoteArrayElement(src) + } + return src +} + +const ( + hsPre = iota + hsKey + hsSep + hsVal + hsNul + hsNext +) + +type hstoreParser struct { + str string + pos int +} + +func newHSP(in string) *hstoreParser { + return &hstoreParser{ + pos: 0, + str: in, + } +} + +func (p *hstoreParser) Consume() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, w := utf8.DecodeRuneInString(p.str[p.pos:]) + p.pos += w + return +} + +func (p *hstoreParser) Peek() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) + return +} + +// parseHstore parses the string representation of an hstore column (the same +// you would get from an ordinary SELECT) into two slices of keys and values. it +// is used internally in the default parsing of hstores. +func parseHstore(s string) (k []string, v []Text, err error) { + if s == "" { + return + } + + buf := bytes.Buffer{} + keys := []string{} + values := []Text{} + p := newHSP(s) + + r, end := p.Consume() + state := hsPre + + for !end { + switch state { + case hsPre: + if r == '"' { + state = hsKey + } else { + err = errors.New("String does not begin with \"") + } + case hsKey: + switch r { + case '"': //End of the key + keys = append(keys, buf.String()) + buf = bytes.Buffer{} + state = hsSep + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsSep: + if r == '=' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=', expecting '>'") + case r == '>': + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") + case r == '"': + state = hsVal + case r == 'N': + state = hsNul + default: + err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) + } + default: + err = fmt.Errorf("Invalid character after '=', expecting '>'") + } + } else { + err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) + } + case hsVal: + switch r { + case '"': //End of the value + values = append(values, Text{String: buf.String(), Valid: true}) + buf = bytes.Buffer{} + state = hsNext + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsNul: + nulBuf := make([]rune, 3) + nulBuf[0] = r + for i := 1; i < 3; i++ { + r, end = p.Consume() + if end { + err = errors.New("Found EOS in NULL value") + return + } + nulBuf[i] = r + } + if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { + values = append(values, Text{}) + state = hsNext + } else { + err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) + } + case hsNext: + if r == ',' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after ',', expcting space") + case (unicode.IsSpace(r)): + r, end = p.Consume() + state = hsKey + default: + err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) + } + } else { + err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) + } + } + + if err != nil { + return + } + r, end = p.Consume() + } + if state != hsNext { + err = errors.New("Improperly formatted hstore") + return + } + k = keys + v = values + return +} + +// Scan implements the database/sql Scanner interface. +func (dst *Hstore) Scan(src interface{}) error { + if src == nil { + *dst = Hstore{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Hstore) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go new file mode 100644 index 00000000..0ca5d4fb --- /dev/null +++ b/pgtype/hstore_array.go @@ -0,0 +1,476 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type HstoreArray struct { + Elements []Hstore + Dimensions []ArrayDimension + Valid bool +} + +func (dst *HstoreArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = HstoreArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []map[string]string: + if value == nil { + *dst = HstoreArray{} + } else if len(value) == 0 { + *dst = HstoreArray{Valid: true} + } else { + elements := make([]Hstore, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = HstoreArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Hstore: + if value == nil { + *dst = HstoreArray{} + } else if len(value) == 0 { + *dst = HstoreArray{Valid: true} + } else { + *dst = HstoreArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = HstoreArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for HstoreArray", src) + } + if elementsLength == 0 { + *dst = HstoreArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to HstoreArray", src) + } + + *dst = HstoreArray{ + Elements: make([]Hstore, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Hstore, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to HstoreArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *HstoreArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to HstoreArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in HstoreArray", err) + } + index++ + + return index, nil +} + +func (dst HstoreArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *HstoreArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]map[string]string: + *v = make([]map[string]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *HstoreArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from HstoreArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from HstoreArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = HstoreArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Hstore + + if len(uta.Elements) > 0 { + elements = make([]Hstore, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Hstore + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = HstoreArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = HstoreArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = HstoreArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Hstore, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = HstoreArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("hstore"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "hstore") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *HstoreArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src HstoreArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/hstore_array_test.go b/pgtype/hstore_array_test.go new file mode 100644 index 00000000..11290fb1 --- /dev/null +++ b/pgtype/hstore_array_test.go @@ -0,0 +1,436 @@ +package pgtype_test + +import ( + "context" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" +) + +func TestHstoreArrayTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + var hstoreOID uint32 + err := conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='hstore';").Scan(&hstoreOID) + if err != nil { + t.Fatalf("did not find hstore OID, %v", err) + } + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.Hstore{}, Name: "hstore", OID: hstoreOID}) + + var hstoreArrayOID uint32 + err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='_hstore';").Scan(&hstoreArrayOID) + if err != nil { + t.Fatalf("did not find _hstore OID, %v", err) + } + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.HstoreArray{}, Name: "_hstore", OID: hstoreArrayOID}) + + text := func(s string) pgtype.Text { + return pgtype.Text{String: s, Valid: true} + } + + values := []pgtype.Hstore{ + {Map: map[string]pgtype.Text{}, Valid: true}, + {Map: map[string]pgtype.Text{"foo": text("bar")}, Valid: true}, + {Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Valid: true}, + {Map: map[string]pgtype.Text{"NULL": text("bar")}, Valid: true}, + {Map: map[string]pgtype.Text{"foo": text("NULL")}, Valid: true}, + {}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Valid: true}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Valid: true}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Valid: true}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Valid: true}) // is key + + // Special value values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Valid: true}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Valid: true}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Valid: true}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Valid: true}) // is key + } + + src := &pgtype.HstoreArray{ + Elements: values, + Dimensions: []pgtype.ArrayDimension{{Length: int32(len(values)), LowerBound: 1}}, + Valid: true, + } + + _, err = conn.Prepare(context.Background(), "test", "select $1::hstore[]") + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, fc := range formats { + queryResultFormats := pgx.QueryResultFormats{fc.formatCode} + vEncoder := testutil.ForceEncoder(src, fc.formatCode) + if vEncoder == nil { + t.Logf("%#v does not implement %v", src, fc.name) + continue + } + + var result pgtype.HstoreArray + err := conn.QueryRow(context.Background(), "test", queryResultFormats, vEncoder).Scan(&result) + if err != nil { + t.Errorf("%v: %v", fc.name, err) + continue + } + + if result.Valid != src.Valid { + t.Errorf("%v: expected Valid %v, got %v", fc.formatCode, src.Valid, result.Valid) + continue + } + + if len(result.Elements) != len(src.Elements) { + t.Errorf("%v: expected %v elements, got %v", fc.formatCode, len(src.Elements), len(result.Elements)) + continue + } + + for i := range result.Elements { + a := src.Elements[i] + b := result.Elements[i] + + if a.Valid != b.Valid { + t.Errorf("%v element idx %d: expected Valid %v, got %v", fc.formatCode, i, a.Valid, b.Valid) + } + + if len(a.Map) != len(b.Map) { + t.Errorf("%v element idx %d: expected %v pairs, got %v", fc.formatCode, i, len(a.Map), len(b.Map)) + } + + for k := range a.Map { + if a.Map[k] != b.Map[k] { + t.Errorf("%v element idx %d: expected key %v to be %v, got %v", fc.formatCode, i, k, a.Map[k], b.Map[k]) + } + } + } + } +} + +func TestHstoreArraySet(t *testing.T) { + successfulTests := []struct { + src interface{} + result pgtype.HstoreArray + }{ + { + src: []map[string]string{{"foo": "bar"}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + }, + { + src: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true, + }, + }, + { + src: [][][][]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, + Valid: true, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true, + }, + }, + { + src: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true, + }, + }, + { + src: [2][1][1][3]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, + Valid: true, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true, + }, + }, + } + + for i, tt := range successfulTests { + var dst pgtype.HstoreArray + err := dst.Set(tt.src) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(dst, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + } + } +} + +func TestHstoreArrayAssignTo(t *testing.T) { + var hstoreSlice []map[string]string + var hstoreSliceDim2 [][]map[string]string + var hstoreSliceDim4 [][][][]map[string]string + var hstoreArrayDim2 [2][1]map[string]string + var hstoreArrayDim4 [2][1][1][3]map[string]string + + simpleTests := []struct { + src pgtype.HstoreArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &hstoreSlice, + expected: []map[string]string{{"foo": "bar"}}}, + { + src: pgtype.HstoreArray{}, dst: &hstoreSlice, expected: (([]map[string]string)(nil)), + }, + { + src: pgtype.HstoreArray{Valid: true}, dst: &hstoreSlice, expected: []map[string]string{}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &hstoreSliceDim2, + expected: [][]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, + Valid: true, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true, + }, + dst: &hstoreSliceDim4, + expected: [][][][]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &hstoreArrayDim2, + expected: [2][1]map[string]string{{{"foo": "bar"}}, {{"baz": "quz"}}}, + }, + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"baz": {String: "quz", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"bar": {String: "baz", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"wibble": {String: "wobble", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"wubble": {String: "wabble", Valid: true}}, + Valid: true, + }, + { + Map: map[string]pgtype.Text{"wabble": {String: "wobble", Valid: true}}, + Valid: true, + }, + }, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true, + }, + dst: &hstoreArrayDim4, + expected: [2][1][1][3]map[string]string{ + {{{{"foo": "bar"}, {"baz": "quz"}, {"bar": "baz"}}}}, + {{{{"wibble": "wobble"}, {"wubble": "wabble"}, {"wabble": "wobble"}}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go new file mode 100644 index 00000000..9c26a3df --- /dev/null +++ b/pgtype/hstore_test.go @@ -0,0 +1,204 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestHstoreTranscode(t *testing.T) { + text := func(s string) pgtype.Text { + return pgtype.Text{String: s, Valid: true} + } + + values := []interface{}{ + &pgtype.Hstore{Map: map[string]pgtype.Text{}, Valid: true}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(""), "bar": text(""), "baz": text("123")}, Valid: true}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Valid: true}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Valid: true}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Valid: true}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Valid: true}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"": text("bar")}, Valid: true}, + &pgtype.Hstore{ + Map: map[string]pgtype.Text{"a": text("a"), "b": {}, "c": text("c"), "d": {}, "e": text("e")}, + Valid: true, + }, + &pgtype.Hstore{}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Valid: true}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Valid: true}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Valid: true}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Valid: true}) // is key + + // Special value values + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Valid: true}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Valid: true}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Valid: true}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Valid: true}) // is key + } + + testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { + a := ai.(pgtype.Hstore) + b := bi.(pgtype.Hstore) + + if len(a.Map) != len(b.Map) || a.Valid != b.Valid { + return false + } + + for k := range a.Map { + if a.Map[k] != b.Map[k] { + return false + } + } + + return true + }) +} + +func TestHstoreTranscodeNullable(t *testing.T) { + text := func(s string, valid bool) pgtype.Text { + return pgtype.Text{String: s, Valid: valid} + } + + values := []interface{}{ + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("", false)}, Valid: true}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("", false)}, Valid: true}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("", false)}, Valid: true}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("", false)}, Valid: true}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("", false)}, Valid: true}) // is key + } + + testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { + a := ai.(pgtype.Hstore) + b := bi.(pgtype.Hstore) + + if len(a.Map) != len(b.Map) || a.Valid != b.Valid { + return false + } + + for k := range a.Map { + if a.Map[k] != b.Map[k] { + return false + } + } + + return true + }) +} + +func TestHstoreSet(t *testing.T) { + successfulTests := []struct { + src map[string]string + result pgtype.Hstore + }{ + {src: map[string]string{"foo": "bar"}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, Valid: true}}, + } + + for i, tt := range successfulTests { + var dst pgtype.Hstore + err := dst.Set(tt.src) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(dst, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + } + } +} + +func TestHstoreSetNullable(t *testing.T) { + successfulTests := []struct { + src map[string]*string + result pgtype.Hstore + }{ + {src: map[string]*string{"foo": nil}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {}}, Valid: true}}, + } + + for i, tt := range successfulTests { + var dst pgtype.Hstore + err := dst.Set(tt.src) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(dst, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + } + } +} + +func TestHstoreAssignTo(t *testing.T) { + var m map[string]string + + simpleTests := []struct { + src pgtype.Hstore + dst *map[string]string + expected map[string]string + }{ + {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {String: "bar", Valid: true}}, Valid: true}, dst: &m, expected: map[string]string{"foo": "bar"}}, + {src: pgtype.Hstore{}, dst: &m, expected: ((map[string]string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(*tt.dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} + +func TestHstoreAssignToNullable(t *testing.T) { + var m map[string]*string + + simpleTests := []struct { + src pgtype.Hstore + dst *map[string]*string + expected map[string]*string + }{ + {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": {}}, Valid: true}, dst: &m, expected: map[string]*string{"foo": nil}}, + {src: pgtype.Hstore{}, dst: &m, expected: ((map[string]*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(*tt.dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/pgtype/inet.go b/pgtype/inet.go new file mode 100644 index 00000000..4b3217a9 --- /dev/null +++ b/pgtype/inet.go @@ -0,0 +1,245 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + "net" +) + +// Network address family is dependent on server socket.h value for AF_INET. +// In practice, all platforms appear to have the same value. See +// src/include/utils/inet.h for more information. +const ( + defaultAFInet = 2 + defaultAFInet6 = 3 +) + +// Inet represents both inet and cidr PostgreSQL types. +type Inet struct { + IPNet *net.IPNet + Valid bool +} + +func (dst *Inet) Set(src interface{}) error { + if src == nil { + *dst = Inet{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case net.IPNet: + *dst = Inet{IPNet: &value, Valid: true} + case net.IP: + if len(value) == 0 { + *dst = Inet{} + } else { + bitCount := len(value) * 8 + mask := net.CIDRMask(bitCount, bitCount) + *dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Valid: true} + } + case string: + ip, ipnet, err := net.ParseCIDR(value) + if err != nil { + ip = net.ParseIP(value) + if ip == nil { + return fmt.Errorf("unable to parse inet address: %s", value) + } + ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + ipnet.Mask = net.CIDRMask(32, 32) + } + } + ipnet.IP = ip + *dst = Inet{IPNet: ipnet, Valid: true} + case *net.IPNet: + if value == nil { + *dst = Inet{} + } else { + return dst.Set(*value) + } + case *net.IP: + if value == nil { + *dst = Inet{} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Inet{} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Inet", value) + } + + return nil +} + +func (dst Inet) Get() interface{} { + if !dst.Valid { + return nil + } + return dst.IPNet +} + +func (src *Inet) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *net.IPNet: + *v = net.IPNet{ + IP: make(net.IP, len(src.IPNet.IP)), + Mask: make(net.IPMask, len(src.IPNet.Mask)), + } + copy(v.IP, src.IPNet.IP) + copy(v.Mask, src.IPNet.Mask) + return nil + case *net.IP: + if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = make(net.IP, len(src.IPNet.IP)) + copy(*v, src.IPNet.IP) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } +} + +func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Inet{} + return nil + } + + var ipnet *net.IPNet + var err error + + if ip := net.ParseIP(string(src)); ip != nil { + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + } + bitCount := len(ip) * 8 + mask := net.CIDRMask(bitCount, bitCount) + ipnet = &net.IPNet{Mask: mask, IP: ip} + } else { + ip, ipnet, err = net.ParseCIDR(string(src)) + if err != nil { + return err + } + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + } + ones, _ := ipnet.Mask.Size() + *ipnet = net.IPNet{IP: ip, Mask: net.CIDRMask(ones, len(ip)*8)} + } + + *dst = Inet{IPNet: ipnet, Valid: true} + return nil +} + +func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Inet{} + return nil + } + + if len(src) != 8 && len(src) != 20 { + return fmt.Errorf("Received an invalid size for a inet: %d", len(src)) + } + + // ignore family + bits := src[1] + // ignore is_cidr + addressLength := src[3] + + var ipnet net.IPNet + ipnet.IP = make(net.IP, int(addressLength)) + copy(ipnet.IP, src[4:]) + if ipv4 := ipnet.IP.To4(); ipv4 != nil { + ipnet.IP = ipv4 + } + ipnet.Mask = net.CIDRMask(int(bits), len(ipnet.IP)*8) + + *dst = Inet{IPNet: &ipnet, Valid: true} + + return nil +} + +func (src Inet) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return append(buf, src.IPNet.String()...), nil +} + +// EncodeBinary encodes src into w. +func (src Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + var family byte + switch len(src.IPNet.IP) { + case net.IPv4len: + family = defaultAFInet + case net.IPv6len: + family = defaultAFInet6 + default: + return nil, fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) + } + + buf = append(buf, family) + + ones, _ := src.IPNet.Mask.Size() + buf = append(buf, byte(ones)) + + // is_cidr is ignored on server + buf = append(buf, 0) + + buf = append(buf, byte(len(src.IPNet.IP))) + + return append(buf, src.IPNet.IP...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Inet) Scan(src interface{}) error { + if src == nil { + *dst = Inet{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Inet) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go new file mode 100644 index 00000000..7f41c4e5 --- /dev/null +++ b/pgtype/inet_array.go @@ -0,0 +1,533 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "net" + "reflect" + + "github.com/jackc/pgio" +) + +type InetArray struct { + Elements []Inet + Dimensions []ArrayDimension + Valid bool +} + +func (dst *InetArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = InetArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []*net.IPNet: + if value == nil { + *dst = InetArray{} + } else if len(value) == 0 { + *dst = InetArray{Valid: true} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []net.IP: + if value == nil { + *dst = InetArray{} + } else if len(value) == 0 { + *dst = InetArray{Valid: true} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*net.IP: + if value == nil { + *dst = InetArray{} + } else if len(value) == 0 { + *dst = InetArray{Valid: true} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Inet: + if value == nil { + *dst = InetArray{} + } else if len(value) == 0 { + *dst = InetArray{Valid: true} + } else { + *dst = InetArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = InetArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for InetArray", src) + } + if elementsLength == 0 { + *dst = InetArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to InetArray", src) + } + + *dst = InetArray{ + Elements: make([]Inet, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Inet, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to InetArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *InetArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to InetArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in InetArray", err) + } + index++ + + return index, nil +} + +func (dst InetArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *InetArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]*net.IPNet: + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]net.IP: + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*net.IP: + *v = make([]*net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *InetArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from InetArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from InetArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = InetArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Inet + + if len(uta.Elements) > 0 { + elements = make([]Inet, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Inet + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = InetArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = InetArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = InetArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Inet, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = InetArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("inet"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "inet") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *InetArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src InetArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/inet_array_test.go b/pgtype/inet_array_test.go new file mode 100644 index 00000000..1019c7eb --- /dev/null +++ b/pgtype/inet_array_test.go @@ -0,0 +1,319 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInetArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "inet[]", []interface{}{ + &pgtype.InetArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.InetArray{}, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, + {}, + {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestInetArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.InetArray + }{ + { + source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]*net.IPNet)(nil)), + result: pgtype.InetArray{}, + }, + { + source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]net.IP)(nil)), + result: pgtype.InetArray{}, + }, + { + source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.InetArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInetArrayAssignTo(t *testing.T) { + var ipnetSlice []*net.IPNet + var ipSlice []net.IP + var ipSliceDim2 [][]net.IP + var ipnetSliceDim4 [][][][]*net.IPNet + var ipArrayDim2 [2][1]net.IP + var ipnetArrayDim4 [2][1][1][3]*net.IPNet + + simpleTests := []struct { + src pgtype.InetArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{nil}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &ipSlice, + expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &ipSlice, + expected: []net.IP{nil}, + }, + { + src: pgtype.InetArray{}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, + { + src: pgtype.InetArray{Valid: true}, + dst: &ipnetSlice, + expected: []*net.IPNet{}, + }, + { + src: pgtype.InetArray{}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, + { + src: pgtype.InetArray{Valid: true}, + dst: &ipSlice, + expected: []net.IP{}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &ipSliceDim2, + expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &ipnetSliceDim4, + expected: [][][][]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &ipArrayDim2, + expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{ + {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, + {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, + {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &ipnetArrayDim4, + expected: [2][1][1][3]*net.IPNet{ + {{{ + mustParseCIDR(t, "127.0.0.1/24"), + mustParseCIDR(t, "10.0.0.1/24"), + mustParseCIDR(t, "172.16.0.1/16")}}}, + {{{ + mustParseCIDR(t, "192.168.0.1/16"), + mustParseCIDR(t, "224.0.0.1/24"), + mustParseCIDR(t, "169.168.0.1/16")}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go new file mode 100644 index 00000000..c2a5dc28 --- /dev/null +++ b/pgtype/inet_test.go @@ -0,0 +1,139 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/assert" +) + +func TestInetTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "inet", []interface{}{ + &pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/32"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "127.0.0.1/8"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "12.34.56.65/32"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "192.168.1.16/24"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "255.0.0.0/8"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "255.255.255.255/32"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "::1/64"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "::/0"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "::1/128"), Valid: true}, + &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), Valid: true}, + &pgtype.Inet{}, + }) +} + +func TestCidrTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "cidr", []interface{}{ + &pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Valid: true}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, + &pgtype.Inet{}, + }) +} + +func TestInetSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Inet + }{ + {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4"), Mask: net.CIDRMask(24, 32)}, Valid: true}}, + {source: "10.0.0.1", result: pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Valid: true}}, + {source: "2607:f8b0:4009:80b::200e", result: pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Valid: true}}, + {source: net.ParseIP(""), result: pgtype.Inet{}}, + } + + for i, tt := range successfulTests { + var r pgtype.Inet + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + assert.Equalf(t, tt.result.Valid, r.Valid, "%d: Status", i) + if tt.result.Valid { + assert.Equalf(t, tt.result.IPNet.Mask, r.IPNet.Mask, "%d: IP", i) + assert.Truef(t, tt.result.IPNet.IP.Equal(r.IPNet.IP), "%d: Mask", i) + } + } +} + +func TestInetAssignTo(t *testing.T) { + var ipnet net.IPNet + var pipnet *net.IPNet + var ip net.IP + var pip *net.IP + + simpleTests := []struct { + src pgtype.Inet + dst interface{} + expected interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, + {src: pgtype.Inet{}, dst: &pip, expected: ((*net.IP)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %#v, but result was %#v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Inet + dst interface{} + expected interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Inet + dst interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Valid: true}, dst: &ip}, + {src: pgtype.Inet{}, dst: &ipnet}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/int2.go b/pgtype/int2.go new file mode 100644 index 00000000..bbfee1cf --- /dev/null +++ b/pgtype/int2.go @@ -0,0 +1,284 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +type Int2 struct { + Int int16 + Valid bool +} + +func (dst *Int2) Set(src interface{}) error { + if src == nil { + *dst = Int2{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case int8: + *dst = Int2{Int: int16(value), Valid: true} + case uint8: + *dst = Int2{Int: int16(value), Valid: true} + case int16: + *dst = Int2{Int: int16(value), Valid: true} + case uint16: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Valid: true} + case int32: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Valid: true} + case uint32: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Valid: true} + case int64: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Valid: true} + case uint64: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Valid: true} + case int: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Valid: true} + case uint: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Valid: true} + case string: + num, err := strconv.ParseInt(value, 10, 16) + if err != nil { + return err + } + *dst = Int2{Int: int16(num), Valid: true} + case float32: + if value > math.MaxInt16 { + return fmt.Errorf("%f is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Valid: true} + case float64: + if value > math.MaxInt16 { + return fmt.Errorf("%f is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Valid: true} + case *int8: + if value == nil { + *dst = Int2{} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Int2{} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Int2{} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Int2{} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Int2{} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Int2{} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Int2{} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Int2{} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Int2{} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Int2{} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Int2{} + } else { + return dst.Set(*value) + } + case *float32: + if value == nil { + *dst = Int2{} + } else { + return dst.Set(*value) + } + case *float64: + if value == nil { + *dst = Int2{} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int2", value) + } + + return nil +} + +func (dst Int2) Get() interface{} { + if !dst.Valid { + return nil + } + return dst.Int +} + +func (src *Int2) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Valid, dst) +} + +func (dst *Int2) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int2{} + return nil + } + + n, err := strconv.ParseInt(string(src), 10, 16) + if err != nil { + return err + } + + *dst = Int2{Int: int16(n), Valid: true} + return nil +} + +func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int2{} + return nil + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + n := int16(binary.BigEndian.Uint16(src)) + *dst = Int2{Int: n, Valid: true} + return nil +} + +func (src Int2) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil +} + +func (src Int2) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return pgio.AppendInt16(buf, src.Int), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int2) Scan(src interface{}) error { + if src == nil { + *dst = Int2{} + return nil + } + + switch src := src.(type) { + case int64: + if src < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", src) + } + if src > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", src) + } + *dst = Int2{Int: int16(src), Valid: true} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int2) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Int), nil +} + +func (src Int2) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil +} diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go new file mode 100644 index 00000000..d96240dc --- /dev/null +++ b/pgtype/int2_array.go @@ -0,0 +1,896 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type Int2Array struct { + Elements []Int2 + Dimensions []ArrayDimension + Valid bool +} + +func (dst *Int2Array) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int2Array{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []int16: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*int16: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []uint16: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*uint16: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []int32: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*int32: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []uint32: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*uint32: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []int64: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*int64: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []uint64: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*uint64: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []int: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*int: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []uint: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*uint: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Int2: + if value == nil { + *dst = Int2Array{} + } else if len(value) == 0 { + *dst = Int2Array{Valid: true} + } else { + *dst = Int2Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = Int2Array{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for Int2Array", src) + } + if elementsLength == 0 { + *dst = Int2Array{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int2Array", src) + } + + *dst = Int2Array{ + Elements: make([]Int2, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int2, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to Int2Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *Int2Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to Int2Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in Int2Array", err) + } + index++ + + return index, nil +} + +func (dst Int2Array) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *Int2Array) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]int16: + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int16: + *v = make([]*int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint16: + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint16: + *v = make([]*uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int32: + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int32: + *v = make([]*int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint32: + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint32: + *v = make([]*uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int: + *v = make([]int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int: + *v = make([]*int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint: + *v = make([]uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint: + *v = make([]*uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from Int2Array") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from Int2Array") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int2Array{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Int2 + + if len(uta.Elements) > 0 { + elements = make([]Int2, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int2 + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int2Array{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Int2Array{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int2, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("int2"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "int2") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int2Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int2Array) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/int2_array_test.go b/pgtype/int2_array_test.go new file mode 100644 index 00000000..78dc532a --- /dev/null +++ b/pgtype/int2_array_test.go @@ -0,0 +1,342 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt2ArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int2[]", []interface{}{ + &pgtype.Int2Array{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.Int2Array{}, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {}, + {Int: 6, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestInt2ArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int2Array + }{ + { + source: []int64{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []int32{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []int16{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []int{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []uint64{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []uint32{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []uint16{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]int16)(nil)), + result: pgtype.Int2Array{}, + }, + { + source: [][]int16{{1}, {2}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1]int16{{1}, {2}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Int2Array + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt2ArrayAssignTo(t *testing.T) { + var int16Slice []int16 + var uint16Slice []uint16 + var namedInt16Slice _int16Slice + var int16SliceDim2 [][]int16 + var int16SliceDim4 [][][][]int16 + var int16ArrayDim2 [2][1]int16 + var int16ArrayDim4 [2][1][1][3]int16 + + simpleTests := []struct { + src pgtype.Int2Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &int16Slice, + expected: []int16{1}, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &uint16Slice, + expected: []uint16{1}, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &namedInt16Slice, + expected: _int16Slice{1}, + }, + { + src: pgtype.Int2Array{}, + dst: &int16Slice, + expected: (([]int16)(nil)), + }, + { + src: pgtype.Int2Array{Valid: true}, + dst: &int16Slice, + expected: []int16{}, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + expected: [][]int16{{1}, {2}}, + dst: &int16SliceDim2, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + expected: [][][][]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int16SliceDim4, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + expected: [2][1]int16{{1}, {2}}, + dst: &int16ArrayDim2, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + expected: [2][1][1][3]int16{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int16ArrayDim4, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int2Array + dst interface{} + }{ + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &int16Slice, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: -1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &uint16Slice, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &int16ArrayDim2, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &int16Slice, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &int16ArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/int2_test.go b/pgtype/int2_test.go new file mode 100644 index 00000000..6ed8fe90 --- /dev/null +++ b/pgtype/int2_test.go @@ -0,0 +1,144 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt2Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ + &pgtype.Int2{Int: math.MinInt16, Valid: true}, + &pgtype.Int2{Int: -1, Valid: true}, + &pgtype.Int2{Int: 0, Valid: true}, + &pgtype.Int2{Int: 1, Valid: true}, + &pgtype.Int2{Int: math.MaxInt16, Valid: true}, + &pgtype.Int2{Int: 0}, + }) +} + +func TestInt2Set(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int2 + }{ + {source: int8(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: int16(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: int32(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: int64(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: int8(-1), result: pgtype.Int2{Int: -1, Valid: true}}, + {source: int16(-1), result: pgtype.Int2{Int: -1, Valid: true}}, + {source: int32(-1), result: pgtype.Int2{Int: -1, Valid: true}}, + {source: int64(-1), result: pgtype.Int2{Int: -1, Valid: true}}, + {source: uint8(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: uint16(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: uint32(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: uint64(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: float32(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: float64(1), result: pgtype.Int2{Int: 1, Valid: true}}, + {source: "1", result: pgtype.Int2{Int: 1, Valid: true}}, + {source: _int8(1), result: pgtype.Int2{Int: 1, Valid: true}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int2 + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt2AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.Int2 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &i, expected: int(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int2{Int: 0}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int2{Int: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Int2 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int2{Int: 42, Valid: true}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int2 + dst interface{} + }{ + {src: pgtype.Int2{Int: 150, Valid: true}, dst: &i8}, + {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui8}, + {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui16}, + {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui32}, + {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui64}, + {src: pgtype.Int2{Int: -1, Valid: true}, dst: &ui}, + {src: pgtype.Int2{Int: 0}, dst: &i16}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/int4.go b/pgtype/int4.go new file mode 100644 index 00000000..6f1e61f3 --- /dev/null +++ b/pgtype/int4.go @@ -0,0 +1,292 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +type Int4 struct { + Int int32 + Valid bool +} + +func (dst *Int4) Set(src interface{}) error { + if src == nil { + *dst = Int4{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case int8: + *dst = Int4{Int: int32(value), Valid: true} + case uint8: + *dst = Int4{Int: int32(value), Valid: true} + case int16: + *dst = Int4{Int: int32(value), Valid: true} + case uint16: + *dst = Int4{Int: int32(value), Valid: true} + case int32: + *dst = Int4{Int: int32(value), Valid: true} + case uint32: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Valid: true} + case int64: + if value < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Valid: true} + case uint64: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Valid: true} + case int: + if value < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Valid: true} + case uint: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Valid: true} + case string: + num, err := strconv.ParseInt(value, 10, 32) + if err != nil { + return err + } + *dst = Int4{Int: int32(num), Valid: true} + case float32: + if value > math.MaxInt32 { + return fmt.Errorf("%f is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Valid: true} + case float64: + if value > math.MaxInt32 { + return fmt.Errorf("%f is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Valid: true} + case *int8: + if value == nil { + *dst = Int4{} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Int4{} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Int4{} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Int4{} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Int4{} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Int4{} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Int4{} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Int4{} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Int4{} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Int4{} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Int4{} + } else { + return dst.Set(*value) + } + case *float32: + if value == nil { + *dst = Int4{} + } else { + return dst.Set(*value) + } + case *float64: + if value == nil { + *dst = Int4{} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int4", value) + } + + return nil +} + +func (dst Int4) Get() interface{} { + if !dst.Valid { + return nil + } + return dst.Int +} + +func (src *Int4) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Valid, dst) +} + +func (dst *Int4) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4{} + return nil + } + + n, err := strconv.ParseInt(string(src), 10, 32) + if err != nil { + return err + } + + *dst = Int4{Int: int32(n), Valid: true} + return nil +} + +func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4{} + return nil + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + n := int32(binary.BigEndian.Uint32(src)) + *dst = Int4{Int: n, Valid: true} + return nil +} + +func (src Int4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil +} + +func (src Int4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return pgio.AppendInt32(buf, src.Int), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4) Scan(src interface{}) error { + if src == nil { + *dst = Int4{} + return nil + } + + switch src := src.(type) { + case int64: + if src < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", src) + } + if src > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", src) + } + *dst = Int4{Int: int32(src), Valid: true} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Int), nil +} + +func (src Int4) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil +} + +func (dst *Int4) UnmarshalJSON(b []byte) error { + var n *int32 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int4{} + } else { + *dst = Int4{Int: *n, Valid: true} + } + + return nil +} diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go new file mode 100644 index 00000000..e725e7a8 --- /dev/null +++ b/pgtype/int4_array.go @@ -0,0 +1,896 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type Int4Array struct { + Elements []Int4 + Dimensions []ArrayDimension + Valid bool +} + +func (dst *Int4Array) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int4Array{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []int16: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*int16: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []uint16: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*uint16: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []int32: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*int32: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []uint32: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*uint32: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []int64: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*int64: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []uint64: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*uint64: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []int: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*int: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []uint: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*uint: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Int4: + if value == nil { + *dst = Int4Array{} + } else if len(value) == 0 { + *dst = Int4Array{Valid: true} + } else { + *dst = Int4Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = Int4Array{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for Int4Array", src) + } + if elementsLength == 0 { + *dst = Int4Array{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int4Array", src) + } + + *dst = Int4Array{ + Elements: make([]Int4, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int4, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to Int4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *Int4Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to Int4Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in Int4Array", err) + } + index++ + + return index, nil +} + +func (dst Int4Array) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *Int4Array) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]int16: + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int16: + *v = make([]*int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint16: + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint16: + *v = make([]*uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int32: + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int32: + *v = make([]*int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint32: + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint32: + *v = make([]*uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int: + *v = make([]int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int: + *v = make([]*int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint: + *v = make([]uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint: + *v = make([]*uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *Int4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from Int4Array") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from Int4Array") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4Array{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Int4 + + if len(uta.Elements) > 0 { + elements = make([]Int4, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int4 + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Int4Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4Array{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Int4Array{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int4, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Int4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("int4"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "int4") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4Array) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/int4_array_test.go b/pgtype/int4_array_test.go new file mode 100644 index 00000000..a9c9acd9 --- /dev/null +++ b/pgtype/int4_array_test.go @@ -0,0 +1,356 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt4ArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int4[]", []interface{}{ + &pgtype.Int4Array{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.Int4Array{}, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {}, + {Int: 6, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestInt4ArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int4Array + expectedError bool + }{ + { + source: []int64{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []int32{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []int16{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []int{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []int{1, math.MaxInt32 + 1, 2}, + expectedError: true, + }, + { + source: []uint64{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []uint32{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []uint16{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]int32)(nil)), + result: pgtype.Int4Array{}, + }, + { + source: [][]int32{{1}, {2}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1]int32{{1}, {2}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Int4Array + err := r.Set(tt.source) + if err != nil { + if tt.expectedError { + continue + } + t.Errorf("%d: %v", i, err) + } + + if tt.expectedError { + t.Errorf("%d: an error was expected, %v", i, tt) + continue + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt4ArrayAssignTo(t *testing.T) { + var int32Slice []int32 + var uint32Slice []uint32 + var namedInt32Slice _int32Slice + var int32SliceDim2 [][]int32 + var int32SliceDim4 [][][][]int32 + var int32ArrayDim2 [2][1]int32 + var int32ArrayDim4 [2][1][1][3]int32 + + simpleTests := []struct { + src pgtype.Int4Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &int32Slice, + expected: []int32{1}, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &uint32Slice, + expected: []uint32{1}, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &namedInt32Slice, + expected: _int32Slice{1}, + }, + { + src: pgtype.Int4Array{}, + dst: &int32Slice, + expected: (([]int32)(nil)), + }, + { + src: pgtype.Int4Array{Valid: true}, + dst: &int32Slice, + expected: []int32{}, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + expected: [][]int32{{1}, {2}}, + dst: &int32SliceDim2, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + expected: [][][][]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int32SliceDim4, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + expected: [2][1]int32{{1}, {2}}, + dst: &int32ArrayDim2, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + expected: [2][1][1][3]int32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int32ArrayDim4, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int4Array + dst interface{} + }{ + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &int32Slice, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: -1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &uint32Slice, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &int32ArrayDim2, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &int32Slice, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &int32ArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/int4_test.go b/pgtype/int4_test.go new file mode 100644 index 00000000..3085babd --- /dev/null +++ b/pgtype/int4_test.go @@ -0,0 +1,186 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt4Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ + &pgtype.Int4{Int: math.MinInt32, Valid: true}, + &pgtype.Int4{Int: -1, Valid: true}, + &pgtype.Int4{Int: 0, Valid: true}, + &pgtype.Int4{Int: 1, Valid: true}, + &pgtype.Int4{Int: math.MaxInt32, Valid: true}, + &pgtype.Int4{Int: 0}, + }) +} + +func TestInt4Set(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int4 + }{ + {source: int8(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: int16(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: int32(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: int64(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: int8(-1), result: pgtype.Int4{Int: -1, Valid: true}}, + {source: int16(-1), result: pgtype.Int4{Int: -1, Valid: true}}, + {source: int32(-1), result: pgtype.Int4{Int: -1, Valid: true}}, + {source: int64(-1), result: pgtype.Int4{Int: -1, Valid: true}}, + {source: uint8(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: uint16(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: uint32(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: uint64(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: float32(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: float64(1), result: pgtype.Int4{Int: 1, Valid: true}}, + {source: "1", result: pgtype.Int4{Int: 1, Valid: true}}, + {source: _int8(1), result: pgtype.Int4{Int: 1, Valid: true}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int4 + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt4AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.Int4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &i, expected: int(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int4{Int: 0}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int4{Int: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Int4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int4{Int: 42, Valid: true}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int4 + dst interface{} + }{ + {src: pgtype.Int4{Int: 150, Valid: true}, dst: &i8}, + {src: pgtype.Int4{Int: 40000, Valid: true}, dst: &i16}, + {src: pgtype.Int4{Int: -1, Valid: true}, dst: &ui8}, + {src: pgtype.Int4{Int: -1, Valid: true}, dst: &ui16}, + {src: pgtype.Int4{Int: -1, Valid: true}, dst: &ui32}, + {src: pgtype.Int4{Int: -1, Valid: true}, dst: &ui64}, + {src: pgtype.Int4{Int: -1, Valid: true}, dst: &ui}, + {src: pgtype.Int4{Int: 0}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} + +func TestInt4MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int4 + result string + }{ + {source: pgtype.Int4{Int: 0}, result: "null"}, + {source: pgtype.Int4{Int: 1, Valid: true}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt4UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int4 + }{ + {source: "null", result: pgtype.Int4{Int: 0}}, + {source: "1", result: pgtype.Int4{Int: 1, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Int4 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/int4range.go b/pgtype/int4range.go new file mode 100644 index 00000000..49503c0d --- /dev/null +++ b/pgtype/int4range.go @@ -0,0 +1,257 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgio" +) + +type Int4range struct { + Lower Int4 + Upper Int4 + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (dst *Int4range) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int4range{} + return nil + } + + switch value := src.(type) { + case Int4range: + *dst = value + case *Int4range: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return fmt.Errorf("cannot convert %v to Int4range", src) + } + + return nil +} + +func (src Int4range) Get() interface{} { + if !src.Valid { + return nil + } + return src +} + +func (src *Int4range) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Int4range) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4range{} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Int4range{Valid: true} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4range{} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Int4range{Valid: true} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4range) Scan(src interface{}) error { + if src == nil { + *dst = Int4range{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4range) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/int4range_test.go b/pgtype/int4range_test.go new file mode 100644 index 00000000..8b990036 --- /dev/null +++ b/pgtype/int4range_test.go @@ -0,0 +1,28 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt4rangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int4range", []interface{}{ + &pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Valid: true}, Upper: pgtype.Int4{Int: 10, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Valid: true}, Upper: pgtype.Int4{Int: -5, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Valid: true}, + &pgtype.Int4range{Upper: pgtype.Int4{Int: 1, Valid: true}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Valid: true}, + &pgtype.Int4range{}, + }) +} + +func TestInt4rangeNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select int4range(1, 10, '(]')", + Value: pgtype.Int4range{Lower: pgtype.Int4{Int: 2, Valid: true}, Upper: pgtype.Int4{Int: 11, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, + }, + }) +} diff --git a/pgtype/int8.go b/pgtype/int8.go new file mode 100644 index 00000000..794f92c6 --- /dev/null +++ b/pgtype/int8.go @@ -0,0 +1,278 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +type Int8 struct { + Int int64 + Valid bool +} + +func (dst *Int8) Set(src interface{}) error { + if src == nil { + *dst = Int8{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case int8: + *dst = Int8{Int: int64(value), Valid: true} + case uint8: + *dst = Int8{Int: int64(value), Valid: true} + case int16: + *dst = Int8{Int: int64(value), Valid: true} + case uint16: + *dst = Int8{Int: int64(value), Valid: true} + case int32: + *dst = Int8{Int: int64(value), Valid: true} + case uint32: + *dst = Int8{Int: int64(value), Valid: true} + case int64: + *dst = Int8{Int: int64(value), Valid: true} + case uint64: + if value > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *dst = Int8{Int: int64(value), Valid: true} + case int: + if int64(value) < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + if int64(value) > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *dst = Int8{Int: int64(value), Valid: true} + case uint: + if uint64(value) > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *dst = Int8{Int: int64(value), Valid: true} + case string: + num, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + *dst = Int8{Int: num, Valid: true} + case float32: + if value > math.MaxInt64 { + return fmt.Errorf("%f is greater than maximum value for Int8", value) + } + *dst = Int8{Int: int64(value), Valid: true} + case float64: + if value > math.MaxInt64 { + return fmt.Errorf("%f is greater than maximum value for Int8", value) + } + *dst = Int8{Int: int64(value), Valid: true} + case *int8: + if value == nil { + *dst = Int8{} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Int8{} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Int8{} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Int8{} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Int8{} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Int8{} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Int8{} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Int8{} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Int8{} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Int8{} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Int8{} + } else { + return dst.Set(*value) + } + case *float32: + if value == nil { + *dst = Int8{} + } else { + return dst.Set(*value) + } + case *float64: + if value == nil { + *dst = Int8{} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int8", value) + } + + return nil +} + +func (dst Int8) Get() interface{} { + if !dst.Valid { + return nil + } + return dst.Int +} + +func (src *Int8) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Valid, dst) +} + +func (dst *Int8) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8{} + return nil + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + *dst = Int8{Int: n, Valid: true} + return nil +} + +func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8{} + return nil + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint64(src)) + + *dst = Int8{Int: n, Valid: true} + return nil +} + +func (src Int8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return append(buf, strconv.FormatInt(src.Int, 10)...), nil +} + +func (src Int8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return pgio.AppendInt64(buf, src.Int), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8) Scan(src interface{}) error { + if src == nil { + *dst = Int8{} + return nil + } + + switch src := src.(type) { + case int64: + *dst = Int8{Int: src, Valid: true} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Int), nil +} + +func (src Int8) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(src.Int, 10)), nil +} + +func (dst *Int8) UnmarshalJSON(b []byte) error { + var n *int64 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int8{} + } else { + *dst = Int8{Int: *n, Valid: true} + } + + return nil +} diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go new file mode 100644 index 00000000..d6f38994 --- /dev/null +++ b/pgtype/int8_array.go @@ -0,0 +1,896 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type Int8Array struct { + Elements []Int8 + Dimensions []ArrayDimension + Valid bool +} + +func (dst *Int8Array) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int8Array{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []int16: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*int16: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []uint16: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*uint16: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []int32: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*int32: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []uint32: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*uint32: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []int64: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*int64: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []uint64: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*uint64: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []int: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*int: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []uint: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*uint: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Int8: + if value == nil { + *dst = Int8Array{} + } else if len(value) == 0 { + *dst = Int8Array{Valid: true} + } else { + *dst = Int8Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = Int8Array{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for Int8Array", src) + } + if elementsLength == 0 { + *dst = Int8Array{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int8Array", src) + } + + *dst = Int8Array{ + Elements: make([]Int8, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int8, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to Int8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *Int8Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to Int8Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in Int8Array", err) + } + index++ + + return index, nil +} + +func (dst Int8Array) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *Int8Array) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]int16: + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int16: + *v = make([]*int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint16: + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint16: + *v = make([]*uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int32: + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int32: + *v = make([]*int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint32: + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint32: + *v = make([]*uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int: + *v = make([]int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int: + *v = make([]*int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint: + *v = make([]uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint: + *v = make([]*uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *Int8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from Int8Array") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from Int8Array") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8Array{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Int8 + + if len(uta.Elements) > 0 { + elements = make([]Int8, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int8 + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Int8Array{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8Array{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Int8Array{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int8, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Int8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("int8"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "int8") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8Array) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/int8_array_test.go b/pgtype/int8_array_test.go new file mode 100644 index 00000000..29eaf8cb --- /dev/null +++ b/pgtype/int8_array_test.go @@ -0,0 +1,349 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt8ArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int8[]", []interface{}{ + &pgtype.Int8Array{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.Int8Array{}, + &pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {}, + {Int: 6, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestInt8ArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int8Array + }{ + { + source: []int64{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []int32{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []int16{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []int{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []uint64{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []uint32{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []uint16{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []uint{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]int64)(nil)), + result: pgtype.Int8Array{}, + }, + { + source: [][]int64{{1}, {2}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1]int64{{1}, {2}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Int8Array + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt8ArrayAssignTo(t *testing.T) { + var int64Slice []int64 + var uint64Slice []uint64 + var namedInt64Slice _int64Slice + var int64SliceDim2 [][]int64 + var int64SliceDim4 [][][][]int64 + var int64ArrayDim2 [2][1]int64 + var int64ArrayDim4 [2][1][1][3]int64 + + simpleTests := []struct { + src pgtype.Int8Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &int64Slice, + expected: []int64{1}, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &uint64Slice, + expected: []uint64{1}, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &namedInt64Slice, + expected: _int64Slice{1}, + }, + { + src: pgtype.Int8Array{}, + dst: &int64Slice, + expected: (([]int64)(nil)), + }, + { + src: pgtype.Int8Array{Valid: true}, + dst: &int64Slice, + expected: []int64{}, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + expected: [][]int64{{1}, {2}}, + dst: &int64SliceDim2, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + expected: [][][][]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int64SliceDim4, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + expected: [2][1]int64{{1}, {2}}, + dst: &int64ArrayDim2, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {Int: 3, Valid: true}, + {Int: 4, Valid: true}, + {Int: 5, Valid: true}, + {Int: 6, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + expected: [2][1][1][3]int64{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + dst: &int64ArrayDim4, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int8Array + dst interface{} + }{ + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &int64Slice, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: -1, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &uint64Slice, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &int64ArrayDim2, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &int64Slice, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Valid: true}, {Int: 2, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &int64ArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/int8_test.go b/pgtype/int8_test.go new file mode 100644 index 00000000..8aca741d --- /dev/null +++ b/pgtype/int8_test.go @@ -0,0 +1,187 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt8Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ + &pgtype.Int8{Int: math.MinInt64, Valid: true}, + &pgtype.Int8{Int: -1, Valid: true}, + &pgtype.Int8{Int: 0, Valid: true}, + &pgtype.Int8{Int: 1, Valid: true}, + &pgtype.Int8{Int: math.MaxInt64, Valid: true}, + &pgtype.Int8{Int: 0}, + }) +} + +func TestInt8Set(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int8 + }{ + {source: int8(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: int16(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: int32(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: int64(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: int8(-1), result: pgtype.Int8{Int: -1, Valid: true}}, + {source: int16(-1), result: pgtype.Int8{Int: -1, Valid: true}}, + {source: int32(-1), result: pgtype.Int8{Int: -1, Valid: true}}, + {source: int64(-1), result: pgtype.Int8{Int: -1, Valid: true}}, + {source: uint8(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: uint16(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: uint32(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: uint64(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: float32(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: float64(1), result: pgtype.Int8{Int: 1, Valid: true}}, + {source: "1", result: pgtype.Int8{Int: 1, Valid: true}}, + {source: _int8(1), result: pgtype.Int8{Int: 1, Valid: true}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int8 + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt8AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.Int8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &i, expected: int(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int8{Int: 0}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int8{Int: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Int8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int8{Int: 42, Valid: true}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int8 + dst interface{} + }{ + {src: pgtype.Int8{Int: 150, Valid: true}, dst: &i8}, + {src: pgtype.Int8{Int: 40000, Valid: true}, dst: &i16}, + {src: pgtype.Int8{Int: 5000000000, Valid: true}, dst: &i32}, + {src: pgtype.Int8{Int: -1, Valid: true}, dst: &ui8}, + {src: pgtype.Int8{Int: -1, Valid: true}, dst: &ui16}, + {src: pgtype.Int8{Int: -1, Valid: true}, dst: &ui32}, + {src: pgtype.Int8{Int: -1, Valid: true}, dst: &ui64}, + {src: pgtype.Int8{Int: -1, Valid: true}, dst: &ui}, + {src: pgtype.Int8{Int: 0}, dst: &i64}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} + +func TestInt8MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int8 + result string + }{ + {source: pgtype.Int8{Int: 0}, result: "null"}, + {source: pgtype.Int8{Int: 1, Valid: true}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestInt8UnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Int8 + }{ + {source: "null", result: pgtype.Int8{Int: 0}}, + {source: "1", result: pgtype.Int8{Int: 1, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Int8 + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/int8range.go b/pgtype/int8range.go new file mode 100644 index 00000000..a7cbcd12 --- /dev/null +++ b/pgtype/int8range.go @@ -0,0 +1,257 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgio" +) + +type Int8range struct { + Lower Int8 + Upper Int8 + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (dst *Int8range) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int8range{} + return nil + } + + switch value := src.(type) { + case Int8range: + *dst = value + case *Int8range: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return fmt.Errorf("cannot convert %v to Int8range", src) + } + + return nil +} + +func (src Int8range) Get() interface{} { + if !src.Valid { + return nil + } + return src +} + +func (src *Int8range) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Int8range) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8range{} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Int8range{Valid: true} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8range{} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Int8range{Valid: true} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8range) Scan(src interface{}) error { + if src == nil { + *dst = Int8range{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8range) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/int8range_test.go b/pgtype/int8range_test.go new file mode 100644 index 00000000..f2e4098d --- /dev/null +++ b/pgtype/int8range_test.go @@ -0,0 +1,28 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestInt8rangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "Int8range", []interface{}{ + &pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Valid: true}, Upper: pgtype.Int8{Int: 10, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Valid: true}, Upper: pgtype.Int8{Int: -5, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Unbounded, Valid: true}, + &pgtype.Int8range{Upper: pgtype.Int8{Int: 1, Valid: true}, LowerType: pgtype.Unbounded, UpperType: pgtype.Exclusive, Valid: true}, + &pgtype.Int8range{}, + }) +} + +func TestInt8rangeNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select Int8range(1, 10, '(]')", + Value: pgtype.Int8range{Lower: pgtype.Int8{Int: 2, Valid: true}, Upper: pgtype.Int8{Int: 11, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true}, + }, + }) +} diff --git a/pgtype/integration_benchmark_test.go b/pgtype/integration_benchmark_test.go new file mode 100644 index 00000000..d3af7c31 --- /dev/null +++ b/pgtype/integration_benchmark_test.go @@ -0,0 +1,1292 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" +) + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int16_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int16 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int32_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int32 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_int64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_uint64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]uint64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_int4_to_Go_pgtype_Int4_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Int4 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::int4 + 0, n::int4 + 1, n::int4 + 2, n::int4 + 3, n::int4 + 4, n::int4 + 5, n::int4 + 6, n::int4 + 7, n::int4 + 8, n::int4 + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_int64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]int64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_float64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_float64_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]float64 + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_1_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 1) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_10_rows_1_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [1]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0 from generate_series(1, 10) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryTextFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.TextFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkQueryBinaryFormatDecode_PG_numeric_to_Go_pgtype_Numeric_100_rows_10_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [10]pgtype.Numeric + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select n::numeric + 0, n::numeric + 1, n::numeric + 2, n::numeric + 3, n::numeric + 4, n::numeric + 5, n::numeric + 6, n::numeric + 7, n::numeric + 8, n::numeric + 9 from generate_series(1, 100) n`, + []interface{}{pgx.QueryResultFormats{pgx.BinaryFormatCode}}, + []interface{}{&v[0], &v[1], &v[2], &v[3], &v[4], &v[5], &v[6], &v[7], &v[8], &v[9]}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/pgtype/integration_benchmark_test.go.erb b/pgtype/integration_benchmark_test.go.erb new file mode 100644 index 00000000..037c96c3 --- /dev/null +++ b/pgtype/integration_benchmark_test.go.erb @@ -0,0 +1,44 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" +) + +<% + [ + ["int4", ["int16", "int32", "int64", "uint64", "pgtype.Int4"], [[1, 1], [1, 10], [10, 1], [100, 10]]], + ["numeric", ["int64", "float64", "pgtype.Numeric"], [[1, 1], [1, 10], [10, 1], [100, 10]]], + ].each do |pg_type, go_types, rows_columns| +%> +<% go_types.each do |go_type| %> +<% rows_columns.each do |rows, columns| %> +<% [["Text", "pgx.TextFormatCode"], ["Binary", "pgx.BinaryFormatCode"]].each do |formatName, formatCode| %> +func BenchmarkQuery<%= formatName %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go_type.gsub(/\W/, "_") %>_<%= rows %>_rows_<%= columns %>_columns(b *testing.B) { + conn := testutil.MustConnectPgx(b) + defer testutil.MustCloseContext(b, conn) + + b.ResetTimer() + var v [<%= columns %>]<%= go_type %> + for i := 0; i < b.N; i++ { + _, err := conn.QueryFunc( + context.Background(), + `select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`, + []interface{}{pgx.QueryResultFormats{<%= formatCode %>}}, + []interface{}{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, + func(pgx.QueryFuncRow) error { return nil }, + ) + if err != nil { + b.Fatal(err) + } + } +} +<% end %> +<% end %> +<% end %> +<% end %> diff --git a/pgtype/integration_benchmark_test_gen.sh b/pgtype/integration_benchmark_test_gen.sh new file mode 100755 index 00000000..22ac01aa --- /dev/null +++ b/pgtype/integration_benchmark_test_gen.sh @@ -0,0 +1,2 @@ +erb integration_benchmark_test.go.erb > integration_benchmark_test.go +goimports -w integration_benchmark_test.go diff --git a/pgtype/interval.go b/pgtype/interval.go new file mode 100644 index 00000000..a92cd41f --- /dev/null +++ b/pgtype/interval.go @@ -0,0 +1,244 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "strings" + "time" + + "github.com/jackc/pgio" +) + +const ( + microsecondsPerSecond = 1000000 + microsecondsPerMinute = 60 * microsecondsPerSecond + microsecondsPerHour = 60 * microsecondsPerMinute + microsecondsPerDay = 24 * microsecondsPerHour + microsecondsPerMonth = 30 * microsecondsPerDay +) + +type Interval struct { + Microseconds int64 + Days int32 + Months int32 + Valid bool +} + +func (dst *Interval) Set(src interface{}) error { + if src == nil { + *dst = Interval{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case time.Duration: + *dst = Interval{Microseconds: int64(value) / 1000, Valid: true} + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Interval", value) + } + + return nil +} + +func (dst Interval) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *Interval) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *time.Duration: + us := int64(src.Months)*microsecondsPerMonth + int64(src.Days)*microsecondsPerDay + src.Microseconds + *v = time.Duration(us) * time.Microsecond + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } +} + +func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Interval{} + return nil + } + + var microseconds int64 + var days int32 + var months int32 + + parts := strings.Split(string(src), " ") + + for i := 0; i < len(parts)-1; i += 2 { + scalar, err := strconv.ParseInt(parts[i], 10, 64) + if err != nil { + return fmt.Errorf("bad interval format") + } + + switch parts[i+1] { + case "year", "years": + months += int32(scalar * 12) + case "mon", "mons": + months += int32(scalar) + case "day", "days": + days = int32(scalar) + } + } + + if len(parts)%2 == 1 { + timeParts := strings.SplitN(parts[len(parts)-1], ":", 3) + if len(timeParts) != 3 { + return fmt.Errorf("bad interval format") + } + + var negative bool + if timeParts[0][0] == '-' { + negative = true + timeParts[0] = timeParts[0][1:] + } + + hours, err := strconv.ParseInt(timeParts[0], 10, 64) + if err != nil { + return fmt.Errorf("bad interval hour format: %s", timeParts[0]) + } + + minutes, err := strconv.ParseInt(timeParts[1], 10, 64) + if err != nil { + return fmt.Errorf("bad interval minute format: %s", timeParts[1]) + } + + secondParts := strings.SplitN(timeParts[2], ".", 2) + + seconds, err := strconv.ParseInt(secondParts[0], 10, 64) + if err != nil { + return fmt.Errorf("bad interval second format: %s", secondParts[0]) + } + + var uSeconds int64 + if len(secondParts) == 2 { + uSeconds, err = strconv.ParseInt(secondParts[1], 10, 64) + if err != nil { + return fmt.Errorf("bad interval decimal format: %s", secondParts[1]) + } + + for i := 0; i < 6-len(secondParts[1]); i++ { + uSeconds *= 10 + } + } + + microseconds = hours * microsecondsPerHour + microseconds += minutes * microsecondsPerMinute + microseconds += seconds * microsecondsPerSecond + microseconds += uSeconds + + if negative { + microseconds = -microseconds + } + } + + *dst = Interval{Months: months, Days: days, Microseconds: microseconds, Valid: true} + return nil +} + +func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Interval{} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("Received an invalid size for a interval: %d", len(src)) + } + + microseconds := int64(binary.BigEndian.Uint64(src)) + days := int32(binary.BigEndian.Uint32(src[8:])) + months := int32(binary.BigEndian.Uint32(src[12:])) + + *dst = Interval{Microseconds: microseconds, Days: days, Months: months, Valid: true} + return nil +} + +func (src Interval) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if src.Months != 0 { + buf = append(buf, strconv.FormatInt(int64(src.Months), 10)...) + buf = append(buf, " mon "...) + } + + if src.Days != 0 { + buf = append(buf, strconv.FormatInt(int64(src.Days), 10)...) + buf = append(buf, " day "...) + } + + absMicroseconds := src.Microseconds + if absMicroseconds < 0 { + absMicroseconds = -absMicroseconds + buf = append(buf, '-') + } + + hours := absMicroseconds / microsecondsPerHour + minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute + seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond + microseconds := absMicroseconds % microsecondsPerSecond + + timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) + return append(buf, timeStr...), nil +} + +// EncodeBinary encodes src into w. +func (src Interval) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = pgio.AppendInt64(buf, src.Microseconds) + buf = pgio.AppendInt32(buf, src.Days) + return pgio.AppendInt32(buf, src.Months), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Interval) Scan(src interface{}) error { + if src == nil { + *dst = Interval{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Interval) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go new file mode 100644 index 00000000..844f3866 --- /dev/null +++ b/pgtype/interval_test.go @@ -0,0 +1,74 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIntervalTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "interval", []interface{}{ + &pgtype.Interval{Microseconds: 1, Valid: true}, + &pgtype.Interval{Microseconds: 1000000, Valid: true}, + &pgtype.Interval{Microseconds: 1000001, Valid: true}, + &pgtype.Interval{Microseconds: 123202800000000, Valid: true}, + &pgtype.Interval{Days: 1, Valid: true}, + &pgtype.Interval{Months: 1, Valid: true}, + &pgtype.Interval{Months: 12, Valid: true}, + &pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Valid: true}, + &pgtype.Interval{Microseconds: -1, Valid: true}, + &pgtype.Interval{Microseconds: -1000000, Valid: true}, + &pgtype.Interval{Microseconds: -1000001, Valid: true}, + &pgtype.Interval{Microseconds: -123202800000000, Valid: true}, + &pgtype.Interval{Days: -1, Valid: true}, + &pgtype.Interval{Months: -1, Valid: true}, + &pgtype.Interval{Months: -12, Valid: true}, + &pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Valid: true}, + &pgtype.Interval{}, + }) +} + +func TestIntervalNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select '1 second'::interval", + Value: &pgtype.Interval{Microseconds: 1000000, Valid: true}, + }, + { + SQL: "select '1.000001 second'::interval", + Value: &pgtype.Interval{Microseconds: 1000001, Valid: true}, + }, + { + SQL: "select '34223 hours'::interval", + Value: &pgtype.Interval{Microseconds: 123202800000000, Valid: true}, + }, + { + SQL: "select '1 day'::interval", + Value: &pgtype.Interval{Days: 1, Valid: true}, + }, + { + SQL: "select '1 month'::interval", + Value: &pgtype.Interval{Months: 1, Valid: true}, + }, + { + SQL: "select '1 year'::interval", + Value: &pgtype.Interval{Months: 12, Valid: true}, + }, + { + SQL: "select '-13 mon'::interval", + Value: &pgtype.Interval{Months: -13, Valid: true}, + }, + }) +} + +func TestIntervalLossyConversionToDuration(t *testing.T) { + interval := &pgtype.Interval{Months: 1, Days: 1, Valid: true} + var d time.Duration + err := interval.AssignTo(&d) + require.NoError(t, err) + assert.EqualValues(t, int64(2678400000000000), d.Nanoseconds()) +} diff --git a/pgtype/json.go b/pgtype/json.go new file mode 100644 index 00000000..580e8505 --- /dev/null +++ b/pgtype/json.go @@ -0,0 +1,189 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/json" + "errors" + "fmt" +) + +type JSON struct { + Bytes []byte + Valid bool +} + +func (dst *JSON) Set(src interface{}) error { + if src == nil { + *dst = JSON{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case string: + *dst = JSON{Bytes: []byte(value), Valid: true} + case *string: + if value == nil { + *dst = JSON{} + } else { + *dst = JSON{Bytes: []byte(*value), Valid: true} + } + case []byte: + if value == nil { + *dst = JSON{} + } else { + *dst = JSON{Bytes: value, Valid: true} + } + // Encode* methods are defined on *JSON. If JSON is passed directly then the + // struct itself would be encoded instead of Bytes. This is clearly a footgun + // so detect and return an error. See https://github.com/jackc/pgx/issues/350. + case JSON: + return errors.New("use pointer to pgtype.JSON instead of value") + // Same as above but for JSONB (because they share implementation) + case JSONB: + return errors.New("use pointer to pgtype.JSONB instead of value") + + default: + buf, err := json.Marshal(value) + if err != nil { + return err + } + *dst = JSON{Bytes: buf, Valid: true} + } + + return nil +} + +func (dst JSON) Get() interface{} { + if !dst.Valid { + return nil + } + + var i interface{} + err := json.Unmarshal(dst.Bytes, &i) + if err != nil { + return dst + } + return i +} + +func (src *JSON) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *string: + if src.Valid { + *v = string(src.Bytes) + } else { + return fmt.Errorf("cannot assign non-valid to %T", dst) + } + case **string: + if src.Valid { + s := string(src.Bytes) + *v = &s + return nil + } else { + *v = nil + return nil + } + case *[]byte: + if !src.Valid { + *v = nil + } else { + buf := make([]byte, len(src.Bytes)) + copy(buf, src.Bytes) + *v = buf + } + default: + data := src.Bytes + if data == nil || !src.Valid { + data = []byte("null") + } + + return json.Unmarshal(data, dst) + } + + return nil +} + +func (JSON) PreferredResultFormat() int16 { + return TextFormatCode +} + +func (dst *JSON) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = JSON{} + return nil + } + + *dst = JSON{Bytes: src, Valid: true} + return nil +} + +func (dst *JSON) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) +} + +func (JSON) PreferredParamFormat() int16 { + return TextFormatCode +} + +func (src JSON) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return append(buf, src.Bytes...), nil +} + +func (src JSON) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return src.EncodeText(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *JSON) Scan(src interface{}) error { + if src == nil { + *dst = JSON{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src JSON) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return src.Bytes, nil +} + +func (src JSON) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return src.Bytes, nil +} + +func (dst *JSON) UnmarshalJSON(b []byte) error { + if b == nil || string(b) == "null" { + *dst = JSON{} + } else { + *dst = JSON{Bytes: b, Valid: true} + } + return nil + +} diff --git a/pgtype/json_test.go b/pgtype/json_test.go new file mode 100644 index 00000000..c56f403f --- /dev/null +++ b/pgtype/json_test.go @@ -0,0 +1,177 @@ +package pgtype_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestJSONTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "json", []interface{}{ + &pgtype.JSON{Bytes: []byte("{}"), Valid: true}, + &pgtype.JSON{Bytes: []byte("null"), Valid: true}, + &pgtype.JSON{Bytes: []byte("42"), Valid: true}, + &pgtype.JSON{Bytes: []byte(`"hello"`), Valid: true}, + &pgtype.JSON{}, + }) +} + +func TestJSONSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.JSON + }{ + {source: "{}", result: pgtype.JSON{Bytes: []byte("{}"), Valid: true}}, + {source: []byte("{}"), result: pgtype.JSON{Bytes: []byte("{}"), Valid: true}}, + {source: ([]byte)(nil), result: pgtype.JSON{}}, + {source: (*string)(nil), result: pgtype.JSON{}}, + {source: []int{1, 2, 3}, result: pgtype.JSON{Bytes: []byte("[1,2,3]"), Valid: true}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Valid: true}}, + } + + for i, tt := range successfulTests { + var d pgtype.JSON + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(d, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestJSONAssignTo(t *testing.T) { + var s string + var ps *string + var b []byte + + rawStringTests := []struct { + src pgtype.JSON + dst *string + expected string + }{ + {src: pgtype.JSON{Bytes: []byte("{}"), Valid: true}, dst: &s, expected: "{}"}, + } + + for i, tt := range rawStringTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + rawBytesTests := []struct { + src pgtype.JSON + dst *[]byte + expected []byte + }{ + {src: pgtype.JSON{Bytes: []byte("{}"), Valid: true}, dst: &b, expected: []byte("{}")}, + {src: pgtype.JSON{}, dst: &b, expected: (([]byte)(nil))}, + } + + for i, tt := range rawBytesTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if bytes.Compare(tt.expected, *tt.dst) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + var mapDst map[string]interface{} + type structDst struct { + Name string `json:"name"` + Age int `json:"age"` + } + var strDst structDst + + unmarshalTests := []struct { + src pgtype.JSON + dst interface{} + expected interface{} + }{ + {src: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Valid: true}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.JSON{Bytes: []byte(`{"name":"John","age":42}`), Valid: true}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + } + for i, tt := range unmarshalTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.JSON + dst **string + expected *string + }{ + {src: pgtype.JSON{}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} + +func TestJSONMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.JSON + result string + }{ + {source: pgtype.JSON{}, result: "null"}, + {source: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Valid: true}, result: "{\"a\": 1}"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestJSONUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.JSON + }{ + {source: "null", result: pgtype.JSON{}}, + {source: "{\"a\": 1}", result: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.JSON + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r.Bytes) != string(tt.result.Bytes) || r.Valid != tt.result.Valid { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go new file mode 100644 index 00000000..38d56499 --- /dev/null +++ b/pgtype/jsonb.go @@ -0,0 +1,82 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +type JSONB JSON + +func (dst *JSONB) Set(src interface{}) error { + return (*JSON)(dst).Set(src) +} + +func (dst JSONB) Get() interface{} { + return (JSON)(dst).Get() +} + +func (src *JSONB) AssignTo(dst interface{}) error { + return (*JSON)(src).AssignTo(dst) +} + +func (JSONB) PreferredResultFormat() int16 { + return TextFormatCode +} + +func (dst *JSONB) DecodeText(ci *ConnInfo, src []byte) error { + return (*JSON)(dst).DecodeText(ci, src) +} + +func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = JSONB{} + return nil + } + + if len(src) == 0 { + return fmt.Errorf("jsonb too short") + } + + if src[0] != 1 { + return fmt.Errorf("unknown jsonb version number %d", src[0]) + } + + *dst = JSONB{Bytes: src[1:], Valid: true} + return nil + +} + +func (JSONB) PreferredParamFormat() int16 { + return TextFormatCode +} + +func (src JSONB) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (JSON)(src).EncodeText(ci, buf) +} + +func (src JSONB) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = append(buf, 1) + return append(buf, src.Bytes...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *JSONB) Scan(src interface{}) error { + return (*JSON)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src JSONB) Value() (driver.Value, error) { + return (JSON)(src).Value() +} + +func (src JSONB) MarshalJSON() ([]byte, error) { + return (JSON)(src).MarshalJSON() +} + +func (dst *JSONB) UnmarshalJSON(b []byte) error { + return (*JSON)(dst).UnmarshalJSON(b) +} diff --git a/pgtype/jsonb_array.go b/pgtype/jsonb_array.go new file mode 100644 index 00000000..81ed9f29 --- /dev/null +++ b/pgtype/jsonb_array.go @@ -0,0 +1,504 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type JSONBArray struct { + Elements []JSONB + Dimensions []ArrayDimension + Valid bool +} + +func (dst *JSONBArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = JSONBArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []string: + if value == nil { + *dst = JSONBArray{} + } else if len(value) == 0 { + *dst = JSONBArray{Valid: true} + } else { + elements := make([]JSONB, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = JSONBArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case [][]byte: + if value == nil { + *dst = JSONBArray{} + } else if len(value) == 0 { + *dst = JSONBArray{Valid: true} + } else { + elements := make([]JSONB, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = JSONBArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []JSONB: + if value == nil { + *dst = JSONBArray{} + } else if len(value) == 0 { + *dst = JSONBArray{Valid: true} + } else { + *dst = JSONBArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = JSONBArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for JSONBArray", src) + } + if elementsLength == 0 { + *dst = JSONBArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to JSONBArray", src) + } + + *dst = JSONBArray{ + Elements: make([]JSONB, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]JSONB, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to JSONBArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *JSONBArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to JSONBArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in JSONBArray", err) + } + index++ + + return index, nil +} + +func (dst JSONBArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *JSONBArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[][]byte: + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *JSONBArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from JSONBArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from JSONBArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *JSONBArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = JSONBArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []JSONB + + if len(uta.Elements) > 0 { + elements = make([]JSONB, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem JSONB + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = JSONBArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *JSONBArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = JSONBArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = JSONBArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]JSONB, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = JSONBArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src JSONBArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src JSONBArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("jsonb"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "jsonb") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *JSONBArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src JSONBArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/jsonb_array_test.go b/pgtype/jsonb_array_test.go new file mode 100644 index 00000000..4f293e9e --- /dev/null +++ b/pgtype/jsonb_array_test.go @@ -0,0 +1,36 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestJSONBArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "jsonb[]", []interface{}{ + &pgtype.JSONBArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.JSONBArray{ + Elements: []pgtype.JSONB{ + {Bytes: []byte(`"foo"`), Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.JSONBArray{}, + &pgtype.JSONBArray{ + Elements: []pgtype.JSONB{ + {Bytes: []byte(`"foo"`), Valid: true}, + {Bytes: []byte("null"), Valid: true}, + {Bytes: []byte("42"), Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, + Valid: true, + }, + }) +} diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go new file mode 100644 index 00000000..41df18fa --- /dev/null +++ b/pgtype/jsonb_test.go @@ -0,0 +1,142 @@ +package pgtype_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestJSONBTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + if _, ok := conn.ConnInfo().DataTypeForName("jsonb"); !ok { + t.Skip("Skipping due to no jsonb type") + } + + testutil.TestSuccessfulTranscode(t, "jsonb", []interface{}{ + &pgtype.JSONB{Bytes: []byte("{}"), Valid: true}, + &pgtype.JSONB{Bytes: []byte("null"), Valid: true}, + &pgtype.JSONB{Bytes: []byte("42"), Valid: true}, + &pgtype.JSONB{Bytes: []byte(`"hello"`), Valid: true}, + &pgtype.JSONB{}, + }) +} + +func TestJSONBSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.JSONB + }{ + {source: "{}", result: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}}, + {source: []byte("{}"), result: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}}, + {source: ([]byte)(nil), result: pgtype.JSONB{}}, + {source: (*string)(nil), result: pgtype.JSONB{}}, + {source: []int{1, 2, 3}, result: pgtype.JSONB{Bytes: []byte("[1,2,3]"), Valid: true}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Valid: true}}, + } + + for i, tt := range successfulTests { + var d pgtype.JSONB + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(d, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestJSONBAssignTo(t *testing.T) { + var s string + var ps *string + var b []byte + + rawStringTests := []struct { + src pgtype.JSONB + dst *string + expected string + }{ + {src: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}, dst: &s, expected: "{}"}, + } + + for i, tt := range rawStringTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + rawBytesTests := []struct { + src pgtype.JSONB + dst *[]byte + expected []byte + }{ + {src: pgtype.JSONB{Bytes: []byte("{}"), Valid: true}, dst: &b, expected: []byte("{}")}, + {src: pgtype.JSONB{}, dst: &b, expected: (([]byte)(nil))}, + } + + for i, tt := range rawBytesTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if bytes.Compare(tt.expected, *tt.dst) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + var mapDst map[string]interface{} + type structDst struct { + Name string `json:"name"` + Age int `json:"age"` + } + var strDst structDst + + unmarshalTests := []struct { + src pgtype.JSONB + dst interface{} + expected interface{} + }{ + {src: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Valid: true}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.JSONB{Bytes: []byte(`{"name":"John","age":42}`), Valid: true}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + } + for i, tt := range unmarshalTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.JSONB + dst **string + expected *string + }{ + {src: pgtype.JSONB{}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/pgtype/line.go b/pgtype/line.go new file mode 100644 index 00000000..c3192b2a --- /dev/null +++ b/pgtype/line.go @@ -0,0 +1,138 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +type Line struct { + A, B, C float64 + Valid bool +} + +func (dst *Line) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Line", src) +} + +func (dst Line) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *Line) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Line{} + return nil + } + + if len(src) < 7 { + return fmt.Errorf("invalid length for Line: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 3) + if len(parts) < 3 { + return fmt.Errorf("invalid format for line") + } + + a, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } + + b, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err + } + + c, err := strconv.ParseFloat(parts[2], 64) + if err != nil { + return err + } + + *dst = Line{A: a, B: b, C: c, Valid: true} + return nil +} + +func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Line{} + return nil + } + + if len(src) != 24 { + return fmt.Errorf("invalid length for Line: %v", len(src)) + } + + a := binary.BigEndian.Uint64(src) + b := binary.BigEndian.Uint64(src[8:]) + c := binary.BigEndian.Uint64(src[16:]) + + *dst = Line{ + A: math.Float64frombits(a), + B: math.Float64frombits(b), + C: math.Float64frombits(c), + Valid: true, + } + return nil +} + +func (src Line) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`{%s,%s,%s}`, + strconv.FormatFloat(src.A, 'f', -1, 64), + strconv.FormatFloat(src.B, 'f', -1, 64), + strconv.FormatFloat(src.C, 'f', -1, 64), + )...) + + return buf, nil +} + +func (src Line) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.A)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.B)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.C)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Line) Scan(src interface{}) error { + if src == nil { + *dst = Line{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Line) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/line_test.go b/pgtype/line_test.go new file mode 100644 index 00000000..c47f6512 --- /dev/null +++ b/pgtype/line_test.go @@ -0,0 +1,38 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestLineTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + if _, ok := conn.ConnInfo().DataTypeForName("line"); !ok { + t.Skip("Skipping due to no line type") + } + + // line may exist but not be usable on 9.3 :( + var isPG93 bool + err := conn.QueryRow(context.Background(), "select version() ~ '9.3'").Scan(&isPG93) + if err != nil { + t.Fatal(err) + } + if isPG93 { + t.Skip("Skipping due to unimplemented line type in PG 9.3") + } + + testutil.TestSuccessfulTranscode(t, "line", []interface{}{ + &pgtype.Line{ + A: 1.23, B: 4.56, C: 7.89012345, + Valid: true, + }, + &pgtype.Line{ + A: -1.23, B: -4.56, C: -7.89, + Valid: true, + }, + &pgtype.Line{}, + }) +} diff --git a/pgtype/lseg.go b/pgtype/lseg.go new file mode 100644 index 00000000..649863ca --- /dev/null +++ b/pgtype/lseg.go @@ -0,0 +1,155 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +type Lseg struct { + P [2]Vec2 + Valid bool +} + +func (dst *Lseg) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Lseg", src) +} + +func (dst Lseg) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *Lseg) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Lseg{} + return nil + } + + if len(src) < 11 { + return fmt.Errorf("invalid length for Lseg: %v", len(src)) + } + + str := string(src[2:]) + + var end int + end = strings.IndexByte(str, ',') + + x1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+3:] + end = strings.IndexByte(str, ',') + + x2, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1 : len(str)-2] + + y2, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Lseg{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true} + return nil +} + +func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Lseg{} + return nil + } + + if len(src) != 32 { + return fmt.Errorf("invalid length for Lseg: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + *dst = Lseg{ + P: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Valid: true, + } + return nil +} + +func (src Lseg) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, + strconv.FormatFloat(src.P[0].X, 'f', -1, 64), + strconv.FormatFloat(src.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(src.P[1].X, 'f', -1, 64), + strconv.FormatFloat(src.P[1].Y, 'f', -1, 64), + )...) + + return buf, nil +} + +func (src Lseg) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Lseg) Scan(src interface{}) error { + if src == nil { + *dst = Lseg{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Lseg) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/lseg_test.go b/pgtype/lseg_test.go new file mode 100644 index 00000000..af2faf3f --- /dev/null +++ b/pgtype/lseg_test.go @@ -0,0 +1,22 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestLsegTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "lseg", []interface{}{ + &pgtype.Lseg{ + P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, + Valid: true, + }, + &pgtype.Lseg{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Valid: true, + }, + &pgtype.Lseg{}, + }) +} diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go new file mode 100644 index 00000000..8d6ab720 --- /dev/null +++ b/pgtype/macaddr.go @@ -0,0 +1,160 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + "net" +) + +type Macaddr struct { + Addr net.HardwareAddr + Valid bool +} + +func (dst *Macaddr) Set(src interface{}) error { + if src == nil { + *dst = Macaddr{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case net.HardwareAddr: + addr := make(net.HardwareAddr, len(value)) + copy(addr, value) + *dst = Macaddr{Addr: addr, Valid: true} + case string: + addr, err := net.ParseMAC(value) + if err != nil { + return err + } + *dst = Macaddr{Addr: addr, Valid: true} + case *net.HardwareAddr: + if value == nil { + *dst = Macaddr{} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Macaddr{} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Macaddr", value) + } + + return nil +} + +func (dst Macaddr) Get() interface{} { + if !dst.Valid { + return nil + } + return dst.Addr +} + +func (src *Macaddr) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *net.HardwareAddr: + *v = make(net.HardwareAddr, len(src.Addr)) + copy(*v, src.Addr) + return nil + case *string: + *v = src.Addr.String() + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } +} + +func (dst *Macaddr) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Macaddr{} + return nil + } + + addr, err := net.ParseMAC(string(src)) + if err != nil { + return err + } + + *dst = Macaddr{Addr: addr, Valid: true} + return nil +} + +func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Macaddr{} + return nil + } + + if len(src) != 6 { + return fmt.Errorf("Received an invalid size for a macaddr: %d", len(src)) + } + + addr := make(net.HardwareAddr, 6) + copy(addr, src) + + *dst = Macaddr{Addr: addr, Valid: true} + + return nil +} + +func (src Macaddr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return append(buf, src.Addr.String()...), nil +} + +// EncodeBinary encodes src into w. +func (src Macaddr) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return append(buf, src.Addr...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Macaddr) Scan(src interface{}) error { + if src == nil { + *dst = Macaddr{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Macaddr) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/macaddr_array.go b/pgtype/macaddr_array.go new file mode 100644 index 00000000..78a93a2d --- /dev/null +++ b/pgtype/macaddr_array.go @@ -0,0 +1,505 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "net" + "reflect" + + "github.com/jackc/pgio" +) + +type MacaddrArray struct { + Elements []Macaddr + Dimensions []ArrayDimension + Valid bool +} + +func (dst *MacaddrArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = MacaddrArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []net.HardwareAddr: + if value == nil { + *dst = MacaddrArray{} + } else if len(value) == 0 { + *dst = MacaddrArray{Valid: true} + } else { + elements := make([]Macaddr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = MacaddrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*net.HardwareAddr: + if value == nil { + *dst = MacaddrArray{} + } else if len(value) == 0 { + *dst = MacaddrArray{Valid: true} + } else { + elements := make([]Macaddr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = MacaddrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Macaddr: + if value == nil { + *dst = MacaddrArray{} + } else if len(value) == 0 { + *dst = MacaddrArray{Valid: true} + } else { + *dst = MacaddrArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = MacaddrArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for MacaddrArray", src) + } + if elementsLength == 0 { + *dst = MacaddrArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to MacaddrArray", src) + } + + *dst = MacaddrArray{ + Elements: make([]Macaddr, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Macaddr, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to MacaddrArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *MacaddrArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to MacaddrArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in MacaddrArray", err) + } + index++ + + return index, nil +} + +func (dst MacaddrArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *MacaddrArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]net.HardwareAddr: + *v = make([]net.HardwareAddr, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*net.HardwareAddr: + *v = make([]*net.HardwareAddr, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *MacaddrArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from MacaddrArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from MacaddrArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *MacaddrArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = MacaddrArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Macaddr + + if len(uta.Elements) > 0 { + elements = make([]Macaddr, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Macaddr + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = MacaddrArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *MacaddrArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = MacaddrArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = MacaddrArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Macaddr, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = MacaddrArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src MacaddrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src MacaddrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("macaddr"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "macaddr") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *MacaddrArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src MacaddrArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/macaddr_array_test.go b/pgtype/macaddr_array_test.go new file mode 100644 index 00000000..a4a55cb0 --- /dev/null +++ b/pgtype/macaddr_array_test.go @@ -0,0 +1,262 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestMacaddrArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "macaddr[]", []interface{}{ + &pgtype.MacaddrArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.MacaddrArray{}, + }) +} + +func TestMacaddrArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.MacaddrArray + }{ + { + source: []net.HardwareAddr{mustParseMacaddr(t, "01:23:45:67:89:ab")}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]net.HardwareAddr)(nil)), + result: pgtype.MacaddrArray{}, + }, + { + source: [][]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Valid: true}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Valid: true}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Valid: true}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + result: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Valid: true}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Valid: true}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Valid: true}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.MacaddrArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestMacaddrArrayAssignTo(t *testing.T) { + var macaddrSlice []net.HardwareAddr + var macaddrSliceDim2 [][]net.HardwareAddr + var macaddrSliceDim4 [][][][]net.HardwareAddr + var macaddrArrayDim2 [2][1]net.HardwareAddr + var macaddrArrayDim4 [2][1][1][3]net.HardwareAddr + + simpleTests := []struct { + src pgtype.MacaddrArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &macaddrSlice, + expected: []net.HardwareAddr{mustParseMacaddr(t, "01:23:45:67:89:ab")}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &macaddrSlice, + expected: []net.HardwareAddr{nil}, + }, + { + src: pgtype.MacaddrArray{}, + dst: &macaddrSlice, + expected: (([]net.HardwareAddr)(nil)), + }, + { + src: pgtype.MacaddrArray{Valid: true}, + dst: &macaddrSlice, + expected: []net.HardwareAddr{}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &macaddrSliceDim2, + expected: [][]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Valid: true}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Valid: true}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Valid: true}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &macaddrSliceDim4, + expected: [][][][]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &macaddrArrayDim2, + expected: [2][1]net.HardwareAddr{ + {mustParseMacaddr(t, "01:23:45:67:89:ab")}, + {mustParseMacaddr(t, "cd:ef:01:23:45:67")}}, + }, + { + src: pgtype.MacaddrArray{ + Elements: []pgtype.Macaddr{ + {Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + {Addr: mustParseMacaddr(t, "cd:ef:01:23:45:67"), Valid: true}, + {Addr: mustParseMacaddr(t, "89:ab:cd:ef:01:23"), Valid: true}, + {Addr: mustParseMacaddr(t, "45:67:89:ab:cd:ef"), Valid: true}, + {Addr: mustParseMacaddr(t, "fe:dc:ba:98:76:54"), Valid: true}, + {Addr: mustParseMacaddr(t, "32:10:fe:dc:ba:98"), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &macaddrArrayDim4, + expected: [2][1][1][3]net.HardwareAddr{ + {{{ + mustParseMacaddr(t, "01:23:45:67:89:ab"), + mustParseMacaddr(t, "cd:ef:01:23:45:67"), + mustParseMacaddr(t, "89:ab:cd:ef:01:23")}}}, + {{{ + mustParseMacaddr(t, "45:67:89:ab:cd:ef"), + mustParseMacaddr(t, "fe:dc:ba:98:76:54"), + mustParseMacaddr(t, "32:10:fe:dc:ba:98")}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go new file mode 100644 index 00000000..dc475c41 --- /dev/null +++ b/pgtype/macaddr_test.go @@ -0,0 +1,78 @@ +package pgtype_test + +import ( + "bytes" + "net" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestMacaddrTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "macaddr", []interface{}{ + &pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + &pgtype.Macaddr{}, + }) +} + +func TestMacaddrSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Macaddr + }{ + { + source: mustParseMacaddr(t, "01:23:45:67:89:ab"), + result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + }, + { + source: "01:23:45:67:89:ab", + result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Macaddr + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestMacaddrAssignTo(t *testing.T) { + { + src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true} + var dst net.HardwareAddr + expected := mustParseMacaddr(t, "01:23:45:67:89:ab") + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare([]byte(dst), []byte(expected)) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true} + var dst string + expected := "01:23:45:67:89:ab" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } +} diff --git a/pgtype/name.go b/pgtype/name.go new file mode 100644 index 00000000..7ce8d25e --- /dev/null +++ b/pgtype/name.go @@ -0,0 +1,58 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// Name is a type used for PostgreSQL's special 63-byte +// name data type, used for identifiers like table names. +// The pg_class.relname column is a good example of where the +// name data type is used. +// +// Note that the underlying Go data type of pgx.Name is string, +// so there is no way to enforce the 63-byte length. Inputting +// a longer name into PostgreSQL will result in silent truncation +// to 63 bytes. +// +// Also, if you have custom-compiled PostgreSQL and set +// NAMEDATALEN to a different value, obviously that number of +// bytes applies, rather than the default 63. +type Name Text + +func (dst *Name) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst Name) Get() interface{} { + return (Text)(dst).Get() +} + +func (src *Name) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Name) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Name) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +func (src Name) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeText(ci, buf) +} + +func (src Name) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Name) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Name) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/pgtype/name_test.go b/pgtype/name_test.go new file mode 100644 index 00000000..5f429d83 --- /dev/null +++ b/pgtype/name_test.go @@ -0,0 +1,98 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestNameTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "name", []interface{}{ + &pgtype.Name{String: "", Valid: true}, + &pgtype.Name{String: "foo", Valid: true}, + &pgtype.Name{}, + }) +} + +func TestNameSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Name + }{ + {source: "foo", result: pgtype.Name{String: "foo", Valid: true}}, + {source: _string("bar"), result: pgtype.Name{String: "bar", Valid: true}}, + {source: (*string)(nil), result: pgtype.Name{}}, + } + + for i, tt := range successfulTests { + var d pgtype.Name + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestNameAssignTo(t *testing.T) { + var s string + var ps *string + + simpleTests := []struct { + src pgtype.Name + dst interface{} + expected interface{} + }{ + {src: pgtype.Name{String: "foo", Valid: true}, dst: &s, expected: "foo"}, + {src: pgtype.Name{}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Name + dst interface{} + expected interface{} + }{ + {src: pgtype.Name{String: "foo", Valid: true}, dst: &ps, expected: "foo"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Name + dst interface{} + }{ + {src: pgtype.Name{}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/new_pg_value.erb b/pgtype/new_pg_value.erb new file mode 100644 index 00000000..71a0da7f --- /dev/null +++ b/pgtype/new_pg_value.erb @@ -0,0 +1,37 @@ +package pgtype + +<% skip_binary ||= false %> +<% skip_text ||= false %> +<% prefer_text_format ||= false %> + +func (<%= go_type %>) BinaryFormatSupported() bool { + return true +} + +func (<%= go_type %>) TextFormatSupported() bool { + return true +} + +func (<%= go_type %>) PreferredFormat() int16 { + return <%= prefer_text_format ? "Text" : "Binary" %>FormatCode +} + +func (dst *<%= go_type %>) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + <% if skip_binary %> return fmt.Errorf("binary format not supported for %T", dst) <% else %> return dst.DecodeBinary(ci, src) <% end %> + case TextFormatCode: + <% if skip_text %> return fmt.Errorf("text format not supported for %T", dst) <% else %> return dst.DecodeText(ci, src) <% end %> + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src <%= go_type %>) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + <% if skip_binary %>return nil, fmt.Errorf("binary format not supported for %T", src)<% else %>return src.EncodeBinary(ci, buf)<% end %> + case TextFormatCode: + <% if skip_text %>return nil, fmt.Errorf("text format not supported for %T", src)<% else %>return src.EncodeText(ci, buf)<% end %> + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/new_pg_value_gen.sh b/pgtype/new_pg_value_gen.sh new file mode 100644 index 00000000..3dad08de --- /dev/null +++ b/pgtype/new_pg_value_gen.sh @@ -0,0 +1,45 @@ +erb go_type=ACLItem skip_binary=true prefer_text_format=true new_pg_value.erb > zzz.aclitem.go +erb go_type=Bit new_pg_value.erb > zzz.bit.go +erb go_type=Bool new_pg_value.erb > zzz.bool.go +erb go_type=Box new_pg_value.erb > zzz.box.go +erb go_type=BPChar prefer_text_format=true new_pg_value.erb > zzz.bpchar.go +erb go_type=Bytea new_pg_value.erb > zzz.bytea.go +erb go_type=CID new_pg_value.erb > zzz.cid.go +erb go_type=CIDR new_pg_value.erb > zzz.cidr.go +erb go_type=Circle new_pg_value.erb > zzz.circle.go +erb go_type=Date new_pg_value.erb > zzz.date.go +erb go_type=Float4 new_pg_value.erb > zzz.float4.go +erb go_type=Float8 new_pg_value.erb > zzz.float8.go +erb go_type=GenericBinary skip_text=true new_pg_value.erb > zzz.generic_binary.go +erb go_type=GenericText skip_binary=true prefer_text_format=true new_pg_value.erb > zzz.generic_text.go +erb go_type=Hstore new_pg_value.erb > zzz.hstore.go +erb go_type=Inet new_pg_value.erb > zzz.inet.go +erb go_type=Int2 new_pg_value.erb > zzz.int2.go +erb go_type=Int4 new_pg_value.erb > zzz.int4.go +erb go_type=Int8 new_pg_value.erb > zzz.int8.go +erb go_type=Interval new_pg_value.erb > zzz.interval.go +erb go_type=JSON prefer_text_format=true new_pg_value.erb > zzz.json.go +erb go_type=JSONB prefer_text_format=true new_pg_value.erb > zzz.jsonb.go +erb go_type=Line new_pg_value.erb > zzz.line.go +erb go_type=Lseg new_pg_value.erb > zzz.lseg.go +erb go_type=Macaddr new_pg_value.erb > zzz.macadder.go +erb go_type=Name new_pg_value.erb > zzz.name.go +erb go_type=Numeric new_pg_value.erb > zzz.numeric.go +erb go_type=OIDValue new_pg_value.erb > zzz.oid_value.go +erb go_type=OID new_pg_value.erb > zzz.oid.go +erb go_type=Path new_pg_value.erb > zzz.path.go +erb go_type=pguint32 new_pg_value.erb > zzz.pguint32.go +erb go_type=Point new_pg_value.erb > zzz.point.go +erb go_type=Polygon new_pg_value.erb > zzz.polygon.go +erb go_type=QChar skip_text=true new_pg_value.erb > zzz.qchar.go +erb go_type=Text prefer_text_format=true new_pg_value.erb > zzz.text.go +erb go_type=TID new_pg_value.erb > zzz.tid.go +erb go_type=Time new_pg_value.erb > zzz.time.go +erb go_type=Timestamp new_pg_value.erb > zzz.timestamp.go +erb go_type=Timestamptz new_pg_value.erb > zzz.timestamptz.go +# erb go_type=Unknown new_pg_value.erb > zzz.unknown.go +erb go_type=UUID new_pg_value.erb > zzz.uuid.go +erb go_type=Varbit new_pg_value.erb > zzz.varbit.go +erb go_type=Varchar prefer_text_format=true new_pg_value.erb > zzz.varchar.go +erb go_type=XID new_pg_value.erb > zzz.xid.go +goimports -w zzz.* diff --git a/pgtype/numeric.go b/pgtype/numeric.go new file mode 100644 index 00000000..b24f433c --- /dev/null +++ b/pgtype/numeric.go @@ -0,0 +1,848 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "math/big" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +// PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 +const nbase = 10000 + +const ( + pgNumericNaN = 0x00000000c0000000 + pgNumericNaNSign = 0xc000 + + pgNumericPosInf = 0x00000000d0000000 + pgNumericPosInfSign = 0xd000 + + pgNumericNegInf = 0x00000000f0000000 + pgNumericNegInfSign = 0xf000 +) + +var big0 *big.Int = big.NewInt(0) +var big1 *big.Int = big.NewInt(1) +var big10 *big.Int = big.NewInt(10) +var big100 *big.Int = big.NewInt(100) +var big1000 *big.Int = big.NewInt(1000) + +var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8) +var bigMinInt8 *big.Int = big.NewInt(math.MinInt8) +var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16) +var bigMinInt16 *big.Int = big.NewInt(math.MinInt16) +var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32) +var bigMinInt32 *big.Int = big.NewInt(math.MinInt32) +var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64) +var bigMinInt64 *big.Int = big.NewInt(math.MinInt64) +var bigMaxInt *big.Int = big.NewInt(int64(maxInt)) +var bigMinInt *big.Int = big.NewInt(int64(minInt)) + +var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8) +var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16) +var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32) +var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64)) +var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint)) + +var bigNBase *big.Int = big.NewInt(nbase) +var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) +var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) +var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) + +type Numeric struct { + Int *big.Int + Exp int32 + NaN bool + InfinityModifier InfinityModifier + Valid bool + + NumericDecoderWrapper func(interface{}) NumericDecoder + Getter func(Numeric) interface{} +} + +func (n *Numeric) NewTypeValue() Value { + return &Numeric{ + NumericDecoderWrapper: n.NumericDecoderWrapper, + Getter: n.Getter, + } +} + +func (n *Numeric) TypeName() string { + return "numeric" +} + +func (dst *Numeric) setNil() { + dst.Int = nil + dst.Exp = 0 + dst.NaN = false + dst.Valid = false +} + +func (dst *Numeric) setNaN() { + dst.Int = nil + dst.Exp = 0 + dst.NaN = true + dst.Valid = true +} + +func (dst *Numeric) setNumber(i *big.Int, exp int32) { + dst.Int = i + dst.Exp = exp + dst.NaN = false + dst.Valid = true +} + +func (dst *Numeric) Set(src interface{}) error { + if src == nil { + dst.setNil() + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case float32: + if math.IsNaN(float64(value)) { + dst.setNaN() + return nil + } else if math.IsInf(float64(value), 1) { + *dst = Numeric{Valid: true, InfinityModifier: Infinity} + return nil + } else if math.IsInf(float64(value), -1) { + *dst = Numeric{Valid: true, InfinityModifier: NegativeInfinity} + return nil + } + num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) + if err != nil { + return err + } + dst.setNumber(num, exp) + case float64: + if math.IsNaN(value) { + dst.setNaN() + return nil + } else if math.IsInf(value, 1) { + *dst = Numeric{Valid: true, InfinityModifier: Infinity} + return nil + } else if math.IsInf(value, -1) { + *dst = Numeric{Valid: true, InfinityModifier: NegativeInfinity} + return nil + } + num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) + if err != nil { + return err + } + dst.setNumber(num, exp) + case int8: + dst.setNumber(big.NewInt(int64(value)), 0) + case uint8: + dst.setNumber(big.NewInt(int64(value)), 0) + case int16: + dst.setNumber(big.NewInt(int64(value)), 0) + case uint16: + dst.setNumber(big.NewInt(int64(value)), 0) + case int32: + dst.setNumber(big.NewInt(int64(value)), 0) + case uint32: + dst.setNumber(big.NewInt(int64(value)), 0) + case int64: + dst.setNumber(big.NewInt(value), 0) + case uint64: + dst.setNumber((&big.Int{}).SetUint64(value), 0) + case int: + dst.setNumber(big.NewInt(int64(value)), 0) + case uint: + dst.setNumber((&big.Int{}).SetUint64(uint64(value)), 0) + case string: + num, exp, err := parseNumericString(value) + if err != nil { + return err + } + dst.setNumber(num, exp) + case *float64: + if value == nil { + dst.setNil() + } else { + return dst.Set(*value) + } + case *float32: + if value == nil { + dst.setNil() + } else { + return dst.Set(*value) + } + case *int8: + if value == nil { + dst.setNil() + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + dst.setNil() + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + dst.setNil() + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + dst.setNil() + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + dst.setNil() + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + dst.setNil() + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + dst.setNil() + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + dst.setNil() + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + dst.setNil() + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + dst.setNil() + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + dst.setNil() + } else { + return dst.Set(*value) + } + case InfinityModifier: + *dst = Numeric{InfinityModifier: value, Valid: true} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Numeric", value) + } + + return nil +} + +func (dst Numeric) Get() interface{} { + if dst.Getter != nil { + return dst.Getter(dst) + } + + if !dst.Valid { + return nil + } + + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst +} + +type NumericDecoder interface { + DecodeNumeric(*Numeric) error +} + +func (src *Numeric) AssignTo(dst interface{}) error { + if d, ok := dst.(NumericDecoder); ok { + return d.DecodeNumeric(src) + } else { + if src.NumericDecoderWrapper != nil { + d = src.NumericDecoderWrapper(dst) + if d != nil { + return d.DecodeNumeric(src) + } + } + } + + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *float32: + f, err := src.toFloat64() + if err != nil { + return err + } + return float64AssignTo(f, src.Valid, dst) + case *float64: + f, err := src.toFloat64() + if err != nil { + return err + } + return float64AssignTo(f, src.Valid, dst) + case *int: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int(normalizedInt.Int64()) + case *int8: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt8) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt8) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int8(normalizedInt.Int64()) + case *int16: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt16) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt16) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int16(normalizedInt.Int64()) + case *int32: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt32) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt32) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int32(normalizedInt.Int64()) + case *int64: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt64) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt64) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = normalizedInt.Int64() + case *uint: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint(normalizedInt.Uint64()) + case *uint8: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint8) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint8(normalizedInt.Uint64()) + case *uint16: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint16) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint16(normalizedInt.Uint64()) + case *uint32: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint32) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint32(normalizedInt.Uint64()) + case *uint64: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint64) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = normalizedInt.Uint64() + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + + return nil +} + +func (dst *Numeric) toBigInt() (*big.Int, error) { + if dst.Exp == 0 { + return dst.Int, nil + } + + num := &big.Int{} + num.Set(dst.Int) + if dst.Exp > 0 { + mul := &big.Int{} + mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil) + num.Mul(num, mul) + return num, nil + } + + div := &big.Int{} + div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil) + remainder := &big.Int{} + num.DivMod(num, div, remainder) + if remainder.Cmp(big0) != 0 { + return nil, fmt.Errorf("cannot convert %v to integer", dst) + } + return num, nil +} + +func (src *Numeric) toFloat64() (float64, error) { + if src.NaN { + return math.NaN(), nil + } else if src.InfinityModifier == Infinity { + return math.Inf(1), nil + } else if src.InfinityModifier == NegativeInfinity { + return math.Inf(-1), nil + } + + buf := make([]byte, 0, 32) + + buf = append(buf, src.Int.String()...) + buf = append(buf, 'e') + buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) + + f, err := strconv.ParseFloat(string(buf), 64) + if err != nil { + return 0, err + } + return f, nil +} + +func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + dst.setNil() + return nil + } + + if string(src) == "NaN" { + dst.setNaN() + return nil + } else if string(src) == "Infinity" { + *dst = Numeric{Valid: true, InfinityModifier: Infinity} + return nil + } else if string(src) == "-Infinity" { + *dst = Numeric{Valid: true, InfinityModifier: NegativeInfinity} + return nil + } + + num, exp, err := parseNumericString(string(src)) + if err != nil { + return err + } + + dst.setNumber(num, exp) + return nil +} + +func parseNumericString(str string) (n *big.Int, exp int32, err error) { + parts := strings.SplitN(str, ".", 2) + digits := strings.Join(parts, "") + + if len(parts) > 1 { + exp = int32(-len(parts[1])) + } else { + for len(digits) > 1 && digits[len(digits)-1] == '0' && digits[len(digits)-2] != '-' { + digits = digits[:len(digits)-1] + exp++ + } + } + + accum := &big.Int{} + if _, ok := accum.SetString(digits, 10); !ok { + return nil, 0, fmt.Errorf("%s is not a number", str) + } + + return accum, exp, nil +} + +func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + dst.setNil() + return nil + } + + if len(src) < 8 { + return fmt.Errorf("numeric incomplete %v", src) + } + + rp := 0 + ndigits := binary.BigEndian.Uint16(src[rp:]) + rp += 2 + weight := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + sign := binary.BigEndian.Uint16(src[rp:]) + rp += 2 + dscale := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if sign == pgNumericNaNSign { + dst.setNaN() + return nil + } else if sign == pgNumericPosInfSign { + *dst = Numeric{Valid: true, InfinityModifier: Infinity} + return nil + } else if sign == pgNumericNegInfSign { + *dst = Numeric{Valid: true, InfinityModifier: NegativeInfinity} + return nil + } + + if ndigits == 0 { + dst.setNumber(big.NewInt(0), 0) + return nil + } + + if len(src[rp:]) < int(ndigits)*2 { + return fmt.Errorf("numeric incomplete %v", src) + } + + accum := &big.Int{} + + for i := 0; i < int(ndigits+3)/4; i++ { + int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:]) + rp += bytesRead + + if i > 0 { + var mul *big.Int + switch digitsRead { + case 1: + mul = bigNBase + case 2: + mul = bigNBaseX2 + case 3: + mul = bigNBaseX3 + case 4: + mul = bigNBaseX4 + default: + return fmt.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) + } + accum.Mul(accum, mul) + } + + accum.Add(accum, big.NewInt(int64accum)) + } + + exp := (int32(weight) - int32(ndigits) + 1) * 4 + + if dscale > 0 { + fracNBaseDigits := int16(int32(ndigits) - int32(weight) - 1) + fracDecimalDigits := fracNBaseDigits * 4 + + if dscale > fracDecimalDigits { + multCount := int(dscale - fracDecimalDigits) + for i := 0; i < multCount; i++ { + accum.Mul(accum, big10) + exp-- + } + } else if dscale < fracDecimalDigits { + divCount := int(fracDecimalDigits - dscale) + for i := 0; i < divCount; i++ { + accum.Div(accum, big10) + exp++ + } + } + } + + reduced := &big.Int{} + remainder := &big.Int{} + if exp >= 0 { + for { + reduced.DivMod(accum, big10, remainder) + if remainder.Cmp(big0) != 0 { + break + } + accum.Set(reduced) + exp++ + } + } + + if sign != 0 { + accum.Neg(accum) + } + + dst.setNumber(accum, exp) + + return nil + +} + +func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { + digits := len(src) / 2 + if digits > 4 { + digits = 4 + } + + rp := 0 + + for i := 0; i < digits; i++ { + if i > 0 { + accum *= nbase + } + accum += int64(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + return accum, rp, digits +} + +func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if src.NaN { + buf = append(buf, "NaN"...) + return buf, nil + } else if src.InfinityModifier == Infinity { + buf = append(buf, "Infinity"...) + return buf, nil + } else if src.InfinityModifier == NegativeInfinity { + buf = append(buf, "-Infinity"...) + return buf, nil + } + + buf = append(buf, src.Int.String()...) + buf = append(buf, 'e') + buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) + return buf, nil +} + +func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if src.NaN { + buf = pgio.AppendUint64(buf, pgNumericNaN) + return buf, nil + } else if src.InfinityModifier == Infinity { + buf = pgio.AppendUint64(buf, pgNumericPosInf) + return buf, nil + } else if src.InfinityModifier == NegativeInfinity { + buf = pgio.AppendUint64(buf, pgNumericNegInf) + return buf, nil + } + + var sign int16 + if src.Int.Cmp(big0) < 0 { + sign = 16384 + } + + absInt := &big.Int{} + wholePart := &big.Int{} + fracPart := &big.Int{} + remainder := &big.Int{} + absInt.Abs(src.Int) + + // Normalize absInt and exp to where exp is always a multiple of 4. This makes + // converting to 16-bit base 10,000 digits easier. + var exp int32 + switch src.Exp % 4 { + case 1, -3: + exp = src.Exp - 1 + absInt.Mul(absInt, big10) + case 2, -2: + exp = src.Exp - 2 + absInt.Mul(absInt, big100) + case 3, -1: + exp = src.Exp - 3 + absInt.Mul(absInt, big1000) + default: + exp = src.Exp + } + + if exp < 0 { + divisor := &big.Int{} + divisor.Exp(big10, big.NewInt(int64(-exp)), nil) + wholePart.DivMod(absInt, divisor, fracPart) + fracPart.Add(fracPart, divisor) + } else { + wholePart = absInt + } + + var wholeDigits, fracDigits []int16 + + for wholePart.Cmp(big0) != 0 { + wholePart.DivMod(wholePart, bigNBase, remainder) + wholeDigits = append(wholeDigits, int16(remainder.Int64())) + } + + if fracPart.Cmp(big0) != 0 { + for fracPart.Cmp(big1) != 0 { + fracPart.DivMod(fracPart, bigNBase, remainder) + fracDigits = append(fracDigits, int16(remainder.Int64())) + } + } + + buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits))) + + var weight int16 + if len(wholeDigits) > 0 { + weight = int16(len(wholeDigits) - 1) + if exp > 0 { + weight += int16(exp / 4) + } + } else { + weight = int16(exp/4) - 1 + int16(len(fracDigits)) + } + buf = pgio.AppendInt16(buf, weight) + + buf = pgio.AppendInt16(buf, sign) + + var dscale int16 + if src.Exp < 0 { + dscale = int16(-src.Exp) + } + buf = pgio.AppendInt16(buf, dscale) + + for i := len(wholeDigits) - 1; i >= 0; i-- { + buf = pgio.AppendInt16(buf, wholeDigits[i]) + } + + for i := len(fracDigits) - 1; i >= 0; i-- { + buf = pgio.AppendInt16(buf, fracDigits[i]) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Numeric) Scan(src interface{}) error { + if src == nil { + dst.setNil() + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Numeric) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + + return string(buf), nil +} + +func (src Numeric) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + if src.NaN { + return []byte(`"NaN"`), nil + } + + intStr := src.Int.String() + buf := &bytes.Buffer{} + exp := int(src.Exp) + if exp > 0 { + buf.WriteString(intStr) + for i := 0; i < exp; i++ { + buf.WriteByte('0') + } + } else if exp < 0 { + if len(intStr) <= -exp { + buf.WriteString("0.") + leadingZeros := -exp - len(intStr) + for i := 0; i < leadingZeros; i++ { + buf.WriteByte('0') + } + buf.WriteString(intStr) + } else if len(intStr) > -exp { + dpPos := len(intStr) + exp + buf.WriteString(intStr[:dpPos]) + buf.WriteByte('.') + buf.WriteString(intStr[dpPos:]) + } + } else { + buf.WriteString(intStr) + } + + return buf.Bytes(), nil +} diff --git a/pgtype/numeric_array.go b/pgtype/numeric_array.go new file mode 100644 index 00000000..3e9298b6 --- /dev/null +++ b/pgtype/numeric_array.go @@ -0,0 +1,672 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type NumericArray struct { + Elements []Numeric + Dimensions []ArrayDimension + Valid bool +} + +func (dst *NumericArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = NumericArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []float32: + if value == nil { + *dst = NumericArray{} + } else if len(value) == 0 { + *dst = NumericArray{Valid: true} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*float32: + if value == nil { + *dst = NumericArray{} + } else if len(value) == 0 { + *dst = NumericArray{Valid: true} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []float64: + if value == nil { + *dst = NumericArray{} + } else if len(value) == 0 { + *dst = NumericArray{Valid: true} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*float64: + if value == nil { + *dst = NumericArray{} + } else if len(value) == 0 { + *dst = NumericArray{Valid: true} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []int64: + if value == nil { + *dst = NumericArray{} + } else if len(value) == 0 { + *dst = NumericArray{Valid: true} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*int64: + if value == nil { + *dst = NumericArray{} + } else if len(value) == 0 { + *dst = NumericArray{Valid: true} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []uint64: + if value == nil { + *dst = NumericArray{} + } else if len(value) == 0 { + *dst = NumericArray{Valid: true} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*uint64: + if value == nil { + *dst = NumericArray{} + } else if len(value) == 0 { + *dst = NumericArray{Valid: true} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Numeric: + if value == nil { + *dst = NumericArray{} + } else if len(value) == 0 { + *dst = NumericArray{Valid: true} + } else { + *dst = NumericArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = NumericArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for NumericArray", src) + } + if elementsLength == 0 { + *dst = NumericArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to NumericArray", src) + } + + *dst = NumericArray{ + Elements: make([]Numeric, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Numeric, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to NumericArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *NumericArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to NumericArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in NumericArray", err) + } + index++ + + return index, nil +} + +func (dst NumericArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *NumericArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]float32: + *v = make([]float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float32: + *v = make([]*float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]float64: + *v = make([]float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float64: + *v = make([]*float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *NumericArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from NumericArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from NumericArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = NumericArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Numeric + + if len(uta.Elements) > 0 { + elements = make([]Numeric, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Numeric + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = NumericArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = NumericArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = NumericArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Numeric, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = NumericArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("numeric"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "numeric") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *NumericArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src NumericArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/numeric_array_test.go b/pgtype/numeric_array_test.go new file mode 100644 index 00000000..ee36d1a7 --- /dev/null +++ b/pgtype/numeric_array_test.go @@ -0,0 +1,305 @@ +package pgtype_test + +import ( + "math" + "math/big" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestNumericArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "numeric[]", []interface{}{ + &pgtype.NumericArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.NumericArray{}, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Valid: true}, + {Int: big.NewInt(2), Valid: true}, + {Int: big.NewInt(3), Valid: true}, + {Int: big.NewInt(4), Valid: true}, + {}, + {Int: big.NewInt(6), Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Valid: true}, + {Int: big.NewInt(2), Valid: true}, + {Int: big.NewInt(3), Valid: true}, + {Int: big.NewInt(4), Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestNumericArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.NumericArray + }{ + { + source: []float32{1}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []float32{float32(math.Copysign(0, -1))}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(0), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []float64{1}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []float64{math.Copysign(0, -1)}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(0), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]float32)(nil)), + result: pgtype.NumericArray{}, + }, + { + source: [][]float32{{1}, {2}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Valid: true}, + {Int: big.NewInt(2), Valid: true}, + {Int: big.NewInt(3), Valid: true}, + {Int: big.NewInt(4), Valid: true}, + {Int: big.NewInt(5), Valid: true}, + {Int: big.NewInt(6), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1]float32{{1}, {2}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Valid: true}, + {Int: big.NewInt(2), Valid: true}, + {Int: big.NewInt(3), Valid: true}, + {Int: big.NewInt(4), Valid: true}, + {Int: big.NewInt(5), Valid: true}, + {Int: big.NewInt(6), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.NumericArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericArrayAssignTo(t *testing.T) { + var float32Slice []float32 + var float64Slice []float64 + var float32SliceDim2 [][]float32 + var float32SliceDim4 [][][][]float32 + var float32ArrayDim2 [2][1]float32 + var float32ArrayDim4 [2][1][1][3]float32 + + simpleTests := []struct { + src pgtype.NumericArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &float32Slice, + expected: []float32{1}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &float64Slice, + expected: []float64{1}, + }, + { + src: pgtype.NumericArray{}, + dst: &float32Slice, + expected: (([]float32)(nil)), + }, + { + src: pgtype.NumericArray{Valid: true}, + dst: &float32Slice, + expected: []float32{}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &float32SliceDim2, + expected: [][]float32{{1}, {2}}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Valid: true}, + {Int: big.NewInt(2), Valid: true}, + {Int: big.NewInt(3), Valid: true}, + {Int: big.NewInt(4), Valid: true}, + {Int: big.NewInt(5), Valid: true}, + {Int: big.NewInt(6), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &float32SliceDim4, + expected: [][][][]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &float32ArrayDim2, + expected: [2][1]float32{{1}, {2}}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + {Int: big.NewInt(1), Valid: true}, + {Int: big.NewInt(2), Valid: true}, + {Int: big.NewInt(3), Valid: true}, + {Int: big.NewInt(4), Valid: true}, + {Int: big.NewInt(5), Valid: true}, + {Int: big.NewInt(6), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &float32ArrayDim4, + expected: [2][1][1][3]float32{{{{1, 2, 3}}}, {{{4, 5, 6}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.NumericArray + dst interface{} + }{ + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &float32Slice, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &float32ArrayDim2, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &float32Slice, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Valid: true}, {Int: big.NewInt(2), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &float32ArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go new file mode 100644 index 00000000..7f0734d0 --- /dev/null +++ b/pgtype/numeric_test.go @@ -0,0 +1,447 @@ +package pgtype_test + +import ( + "context" + "encoding/json" + "math" + "math/big" + "math/rand" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +// For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) +func numericEqual(left, right *pgtype.Numeric) bool { + return left.Valid == right.Valid && + left.Exp == right.Exp && + ((left.Int == nil && right.Int == nil) || (left.Int != nil && right.Int != nil && left.Int.Cmp(right.Int) == 0)) && + left.NaN == right.NaN +} + +// For test purposes only. +func numericNormalizedEqual(left, right *pgtype.Numeric) bool { + if left.Valid != right.Valid { + return false + } + + normLeft := &pgtype.Numeric{Int: (&big.Int{}).Set(left.Int), Valid: left.Valid} + normRight := &pgtype.Numeric{Int: (&big.Int{}).Set(right.Int), Valid: right.Valid} + + if left.Exp < right.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(right.Exp-left.Exp)), nil) + normRight.Int.Mul(normRight.Int, mul) + } else if left.Exp > right.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(left.Exp-right.Exp)), nil) + normLeft.Int.Mul(normLeft.Int, mul) + } + + return normLeft.Int.Cmp(normRight.Int) == 0 +} + +func mustParseBigInt(t *testing.T, src string) *big.Int { + i := &big.Int{} + if _, ok := i.SetString(src, 10); !ok { + t.Fatalf("could not parse big.Int: %s", src) + } + return i +} + +func TestNumericNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select '0'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Valid: true}, + }, + { + SQL: "select '1'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Valid: true}, + }, + { + SQL: "select '10.00'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Valid: true}, + }, + { + SQL: "select '1e-3'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Valid: true}, + }, + { + SQL: "select '-1'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Valid: true}, + }, + { + SQL: "select '10000'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Valid: true}, + }, + { + SQL: "select '3.14'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Valid: true}, + }, + { + SQL: "select '1.1'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Valid: true}, + }, + { + SQL: "select '100010001'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Valid: true}, + }, + { + SQL: "select '100010001.0001'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Valid: true}, + }, + { + SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", + Value: &pgtype.Numeric{ + Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"), + Exp: -41, + Valid: true, + }, + }, + { + SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", + Value: &pgtype.Numeric{ + Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), + Exp: -196, + Valid: true, + }, + }, + { + SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", + Value: &pgtype.Numeric{ + Int: mustParseBigInt(t, "123"), + Exp: -186, + Valid: true, + }, + }, + }) +} + +func TestNumericTranscode(t *testing.T) { + max := new(big.Int).Exp(big.NewInt(10), big.NewInt(147454), nil) + max.Add(max, big.NewInt(1)) + longestNumeric := &pgtype.Numeric{Int: max, Exp: -16383, Valid: true} + + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ + &pgtype.Numeric{NaN: true, Valid: true}, + &pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, + &pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, + + &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(1), Exp: 6, Valid: true}, + + // preserves significant zeroes + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -1, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -2, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -3, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -4, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -5, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -6, Valid: true}, + + &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -7, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -8, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -9, Valid: true}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -1500, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3723409723490243842378942378901237502734019231380123"), Exp: 81, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "723409723490243842378942378901237502734019231380123"), Exp: 82, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "23409723490243842378942378901237502734019231380123"), Exp: 83, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3409723490243842378942378901237502734019231380123"), Exp: 84, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3423409823409243892349028349023482934092340892390101"), Exp: -91, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "423409823409243892349028349023482934092340892390101"), Exp: -92, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "23409823409243892349028349023482934092340892390101"), Exp: -93, Valid: true}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3409823409243892349028349023482934092340892390101"), Exp: -94, Valid: true}, + + longestNumeric, + + &pgtype.Numeric{}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Numeric) + b := bb.(pgtype.Numeric) + + return numericEqual(&a, &b) + }) + +} + +func TestNumericTranscodeFuzz(t *testing.T) { + r := rand.New(rand.NewSource(0)) + max := &big.Int{} + max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) + + values := make([]interface{}, 0, 2000) + for i := 0; i < 10; i++ { + for j := -50; j < 50; j++ { + num := (&big.Int{}).Rand(r, max) + negNum := &big.Int{} + negNum.Neg(num) + values = append(values, &pgtype.Numeric{Int: num, Exp: int32(j), Valid: true}) + values = append(values, &pgtype.Numeric{Int: negNum, Exp: int32(j), Valid: true}) + } + } + + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, + func(aa, bb interface{}) bool { + a := aa.(pgtype.Numeric) + b := bb.(pgtype.Numeric) + + return numericNormalizedEqual(&a, &b) + }) +} + +func TestNumericSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result *pgtype.Numeric + }{ + {source: float32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: float32(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Valid: true}}, + {source: float64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: float64(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Valid: true}}, + {source: int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: int16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: int32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: int64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: int8(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}}, + {source: int16(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}}, + {source: int32(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}}, + {source: int64(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}}, + {source: uint8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: uint16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: uint32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: uint64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: "1", result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: _int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Valid: true}}, + {source: float64(1000), result: &pgtype.Numeric{Int: big.NewInt(1), Exp: 3, Valid: true}}, + {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Valid: true}}, + {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Valid: true}}, + {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Valid: true}}, + {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Valid: true, NaN: true}}, + {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Valid: true, NaN: true}}, + {source: pgtype.Infinity, result: &pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}}, + {source: math.Inf(1), result: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}}, + {source: float32(math.Inf(1)), result: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}}, + {source: pgtype.NegativeInfinity, result: &pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, + {source: math.Inf(-1), result: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.NegativeInfinity}}, + {source: float32(math.Inf(1)), result: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}}, + } + + for i, tt := range successfulTests { + r := &pgtype.Numeric{} + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !numericEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericAssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src *pgtype.Numeric + dst interface{} + expected interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &f32, expected: float32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &f64, expected: float64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Valid: true}, dst: &f32, expected: float32(4.2)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Valid: true}, dst: &f64, expected: float64(4.2)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &i16, expected: int16(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &i32, expected: int32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &i64, expected: int64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: 3, Valid: true}, dst: &i64, expected: int64(42000)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &i, expected: int(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &ui8, expected: uint8(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &ui16, expected: uint16(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &ui64, expected: uint64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &ui, expected: uint(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &_i8, expected: _int8(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(0)}, dst: &pi8, expected: ((*int8)(nil))}, + {src: &pgtype.Numeric{Int: big.NewInt(0)}, dst: &_pi8, expected: ((*_int8)(nil))}, + {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Valid: true}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgtype/issues/27 + {src: &pgtype.Numeric{Valid: true, NaN: true}, dst: &f64, expected: math.NaN()}, + {src: &pgtype.Numeric{Valid: true, NaN: true}, dst: &f32, expected: float32(math.NaN())}, + {src: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}, dst: &f64, expected: math.Inf(1)}, + {src: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}, dst: &f32, expected: float32(math.Inf(1))}, + {src: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.NegativeInfinity}, dst: &f64, expected: math.Inf(-1)}, + {src: &pgtype.Numeric{Valid: true, InfinityModifier: pgtype.NegativeInfinity}, dst: &f32, expected: float32(math.Inf(-1))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + dst := reflect.ValueOf(tt.dst).Elem().Interface() + switch dstTyped := dst.(type) { + case float32: + nanExpected := math.IsNaN(float64(tt.expected.(float32))) + if nanExpected && !math.IsNaN(float64(dstTyped)) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } else if !nanExpected && dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + case float64: + nanExpected := math.IsNaN(tt.expected.(float64)) + if nanExpected && !math.IsNaN(dstTyped) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } else if !nanExpected && dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + default: + if dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + } + + pointerAllocTests := []struct { + src *pgtype.Numeric + dst interface{} + expected interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &pf32, expected: float32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Valid: true}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src *pgtype.Numeric + dst interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(150), Valid: true}, dst: &i8}, + {src: &pgtype.Numeric{Int: big.NewInt(40000), Valid: true}, dst: &i16}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}, dst: &ui8}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}, dst: &ui16}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}, dst: &ui32}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}, dst: &ui64}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Valid: true}, dst: &ui}, + {src: &pgtype.Numeric{Int: big.NewInt(0)}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} + +func TestNumericEncodeDecodeBinary(t *testing.T) { + ci := pgtype.NewConnInfo() + tests := []interface{}{ + 123, + 0.000012345, + 1.00002345, + math.NaN(), + float32(math.NaN()), + math.Inf(1), + float32(math.Inf(1)), + math.Inf(-1), + float32(math.Inf(-1)), + } + + for i, tt := range tests { + toString := func(n *pgtype.Numeric) string { + ci := pgtype.NewConnInfo() + text, err := n.EncodeText(ci, nil) + if err != nil { + t.Errorf("%d (EncodeText): %v", i, err) + } + return string(text) + } + numeric := &pgtype.Numeric{} + numeric.Set(tt) + + encoded, err := numeric.EncodeBinary(ci, nil) + if err != nil { + t.Errorf("%d (EncodeBinary): %v", i, err) + } + decoded := &pgtype.Numeric{} + err = decoded.DecodeBinary(ci, encoded) + if err != nil { + t.Errorf("%d (DecodeBinary): %v", i, err) + } + + text0 := toString(numeric) + text1 := toString(decoded) + + if text0 != text1 { + t.Errorf("%d: expected %v to equal to %v, but doesn't", i, text0, text1) + } + } +} + +func TestNumericMarshalJSON(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + for i, tt := range []struct { + decString string + }{ + {"NaN"}, + {"0"}, + {"1"}, + {"-1"}, + {"1000000000000000000"}, + {"1234.56789"}, + {"1.56789"}, + {"0.00000000000056789"}, + {"0.00123000"}, + {"123e-3"}, + {"243723409723490243842378942378901237502734019231380123e23790"}, + {"3409823409243892349028349023482934092340892390101e-14021"}, + } { + var num pgtype.Numeric + var pgJSON string + err := conn.QueryRow(context.Background(), `select $1::numeric, to_json($1::numeric)`, tt.decString).Scan(&num, &pgJSON) + require.NoErrorf(t, err, "%d", i) + + goJSON, err := json.Marshal(num) + require.NoErrorf(t, err, "%d", i) + + require.Equal(t, pgJSON, string(goJSON)) + } +} diff --git a/pgtype/numrange.go b/pgtype/numrange.go new file mode 100644 index 00000000..f1118d83 --- /dev/null +++ b/pgtype/numrange.go @@ -0,0 +1,257 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgio" +) + +type Numrange struct { + Lower Numeric + Upper Numeric + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (dst *Numrange) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Numrange{} + return nil + } + + switch value := src.(type) { + case Numrange: + *dst = value + case *Numrange: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return fmt.Errorf("cannot convert %v to Numrange", src) + } + + return nil +} + +func (src Numrange) Get() interface{} { + if !src.Valid { + return nil + } + return src +} + +func (src *Numrange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Numrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numrange{} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Numrange{Valid: true} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numrange{} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Numrange{Valid: true} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Numrange) Scan(src interface{}) error { + if src == nil { + *dst = Numrange{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Numrange) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/numrange_test.go b/pgtype/numrange_test.go new file mode 100644 index 00000000..b9ea7658 --- /dev/null +++ b/pgtype/numrange_test.go @@ -0,0 +1,46 @@ +package pgtype_test + +import ( + "math/big" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestNumrangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "numrange", []interface{}{ + &pgtype.Numrange{ + LowerType: pgtype.Empty, + UpperType: pgtype.Empty, + Valid: true, + }, + &pgtype.Numrange{ + Lower: pgtype.Numeric{Int: big.NewInt(-543), Exp: 3, Valid: true}, + Upper: pgtype.Numeric{Int: big.NewInt(342), Exp: 1, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + &pgtype.Numrange{ + Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Valid: true}, + Upper: pgtype.Numeric{Int: big.NewInt(-5), Exp: 0, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + &pgtype.Numrange{ + Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Unbounded, + Valid: true, + }, + &pgtype.Numrange{ + Upper: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Valid: true}, + LowerType: pgtype.Unbounded, + UpperType: pgtype.Exclusive, + Valid: true, + }, + &pgtype.Numrange{}, + }) +} diff --git a/pgtype/oid.go b/pgtype/oid.go new file mode 100644 index 00000000..31677e89 --- /dev/null +++ b/pgtype/oid.go @@ -0,0 +1,81 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + + "github.com/jackc/pgio" +) + +// OID (Object Identifier Type) is, according to +// https://www.postgresql.org/docs/current/static/datatype-oid.html, used +// internally by PostgreSQL as a primary key for various system tables. It is +// currently implemented as an unsigned four-byte integer. Its definition can be +// found in src/include/postgres_ext.h in the PostgreSQL sources. Because it is +// so frequently required to be in a NOT NULL condition OID cannot be NULL. To +// allow for NULL OIDs use OIDValue. +type OID uint32 + +func (dst *OID) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + return fmt.Errorf("cannot decode nil into OID") + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + *dst = OID(n) + return nil +} + +func (dst *OID) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + return fmt.Errorf("cannot decode nil into OID") + } + + if len(src) != 4 { + return fmt.Errorf("invalid length: %v", len(src)) + } + + n := binary.BigEndian.Uint32(src) + *dst = OID(n) + return nil +} + +func (src OID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return append(buf, strconv.FormatUint(uint64(src), 10)...), nil +} + +func (src OID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return pgio.AppendUint32(buf, uint32(src)), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *OID) Scan(src interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", src) + } + + switch src := src.(type) { + case int64: + *dst = OID(src) + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src OID) Value() (driver.Value, error) { + return int64(src), nil +} diff --git a/pgtype/oid_value.go b/pgtype/oid_value.go new file mode 100644 index 00000000..5dc9136c --- /dev/null +++ b/pgtype/oid_value.go @@ -0,0 +1,55 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// OIDValue (Object Identifier Type) is, according to +// https://www.postgresql.org/docs/current/static/datatype-OIDValue.html, used +// internally by PostgreSQL as a primary key for various system tables. It is +// currently implemented as an unsigned four-byte integer. Its definition can be +// found in src/include/postgres_ext.h in the PostgreSQL sources. +type OIDValue pguint32 + +// Set converts from src to dst. Note that as OIDValue is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *OIDValue) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst OIDValue) Get() interface{} { + return (pguint32)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as OIDValue is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *OIDValue) AssignTo(dst interface{}) error { + return (*pguint32)(src).AssignTo(dst) +} + +func (dst *OIDValue) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) +} + +func (dst *OIDValue) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) +} + +func (src OIDValue) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeText(ci, buf) +} + +func (src OIDValue) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *OIDValue) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src OIDValue) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/pgtype/oid_value_test.go b/pgtype/oid_value_test.go new file mode 100644 index 00000000..021f81d3 --- /dev/null +++ b/pgtype/oid_value_test.go @@ -0,0 +1,95 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestOIDValueTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "oid", []interface{}{ + &pgtype.OIDValue{Uint: 42, Valid: true}, + &pgtype.OIDValue{}, + }) +} + +func TestOIDValueSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.OIDValue + }{ + {source: uint32(1), result: pgtype.OIDValue{Uint: 1, Valid: true}}, + } + + for i, tt := range successfulTests { + var r pgtype.OIDValue + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestOIDValueAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.OIDValue + dst interface{} + expected interface{} + }{ + {src: pgtype.OIDValue{Uint: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.OIDValue{}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.OIDValue + dst interface{} + expected interface{} + }{ + {src: pgtype.OIDValue{Uint: 42, Valid: true}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.OIDValue + dst interface{} + }{ + {src: pgtype.OIDValue{}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/path.go b/pgtype/path.go new file mode 100644 index 00000000..7ac38c68 --- /dev/null +++ b/pgtype/path.go @@ -0,0 +1,185 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +type Path struct { + P []Vec2 + Closed bool + Valid bool +} + +func (dst *Path) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Path", src) +} + +func (dst Path) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *Path) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Path{} + return nil + } + + if len(src) < 7 { + return fmt.Errorf("invalid length for Path: %v", len(src)) + } + + closed := src[0] == '(' + points := make([]Vec2, 0) + + str := string(src[2:]) + + for { + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + points = append(points, Vec2{x, y}) + + if end+3 < len(str) { + str = str[end+3:] + } else { + break + } + } + + *dst = Path{P: points, Closed: closed, Valid: true} + return nil +} + +func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Path{} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for Path: %v", len(src)) + } + + closed := src[0] == 1 + pointCount := int(binary.BigEndian.Uint32(src[1:])) + + rp := 5 + + if 5+pointCount*16 != len(src) { + return fmt.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) + } + + points := make([]Vec2, pointCount) + for i := 0; i < len(points); i++ { + x := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + y := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} + } + + *dst = Path{ + P: points, + Closed: closed, + Valid: true, + } + return nil +} + +func (src Path) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + var startByte, endByte byte + if src.Closed { + startByte = '(' + endByte = ')' + } else { + startByte = '[' + endByte = ']' + } + buf = append(buf, startByte) + + for i, p := range src.P { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(p.X, 'f', -1, 64), + strconv.FormatFloat(p.Y, 'f', -1, 64), + )...) + } + + return append(buf, endByte), nil +} + +func (src Path) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + var closeByte byte + if src.Closed { + closeByte = 1 + } + buf = append(buf, closeByte) + + buf = pgio.AppendInt32(buf, int32(len(src.P))) + + for _, p := range src.P { + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Path) Scan(src interface{}) error { + if src == nil { + *dst = Path{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Path) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/path_test.go b/pgtype/path_test.go new file mode 100644 index 00000000..9a66996e --- /dev/null +++ b/pgtype/path_test.go @@ -0,0 +1,29 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestPathTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "path", []interface{}{ + &pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, + Closed: false, + Valid: true, + }, + &pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, + Closed: true, + Valid: true, + }, + &pgtype.Path{ + P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Closed: true, + Valid: true, + }, + &pgtype.Path{}, + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go new file mode 100644 index 00000000..d8dd5abf --- /dev/null +++ b/pgtype/pgtype.go @@ -0,0 +1,941 @@ +package pgtype + +import ( + "database/sql" + "encoding/binary" + "fmt" + "math" + "net" + "reflect" + "time" +) + +// PostgreSQL oids for common types +const ( + BoolOID = 16 + ByteaOID = 17 + QCharOID = 18 + NameOID = 19 + Int8OID = 20 + Int2OID = 21 + Int4OID = 23 + TextOID = 25 + OIDOID = 26 + TIDOID = 27 + XIDOID = 28 + CIDOID = 29 + JSONOID = 114 + PointOID = 600 + LsegOID = 601 + PathOID = 602 + BoxOID = 603 + PolygonOID = 604 + LineOID = 628 + CIDROID = 650 + CIDRArrayOID = 651 + Float4OID = 700 + Float8OID = 701 + CircleOID = 718 + UnknownOID = 705 + MacaddrOID = 829 + InetOID = 869 + BoolArrayOID = 1000 + Int2ArrayOID = 1005 + Int4ArrayOID = 1007 + TextArrayOID = 1009 + ByteaArrayOID = 1001 + BPCharArrayOID = 1014 + VarcharArrayOID = 1015 + Int8ArrayOID = 1016 + Float4ArrayOID = 1021 + Float8ArrayOID = 1022 + ACLItemOID = 1033 + ACLItemArrayOID = 1034 + InetArrayOID = 1041 + BPCharOID = 1042 + VarcharOID = 1043 + DateOID = 1082 + TimeOID = 1083 + TimestampOID = 1114 + TimestampArrayOID = 1115 + DateArrayOID = 1182 + TimestamptzOID = 1184 + TimestamptzArrayOID = 1185 + IntervalOID = 1186 + NumericArrayOID = 1231 + BitOID = 1560 + VarbitOID = 1562 + NumericOID = 1700 + RecordOID = 2249 + UUIDOID = 2950 + UUIDArrayOID = 2951 + JSONBOID = 3802 + JSONBArrayOID = 3807 + DaterangeOID = 3912 + Int4rangeOID = 3904 + NumrangeOID = 3906 + TsrangeOID = 3908 + TsrangeArrayOID = 3909 + TstzrangeOID = 3910 + TstzrangeArrayOID = 3911 + Int8rangeOID = 3926 +) + +type InfinityModifier int8 + +const ( + Infinity InfinityModifier = 1 + None InfinityModifier = 0 + NegativeInfinity InfinityModifier = -Infinity +) + +func (im InfinityModifier) String() string { + switch im { + case None: + return "none" + case Infinity: + return "infinity" + case NegativeInfinity: + return "-infinity" + default: + return "invalid" + } +} + +// PostgreSQL format codes +const ( + TextFormatCode = 0 + BinaryFormatCode = 1 +) + +// Value translates values to and from an internal canonical representation for the type. To actually be usable a type +// that implements Value should also implement some combination of BinaryDecoder, BinaryEncoder, TextDecoder, +// and TextEncoder. +// +// Operations that update a Value (e.g. Set, DecodeText, DecodeBinary) should entirely replace the value. e.g. Internal +// slices should be replaced not resized and reused. This allows Get and AssignTo to return a slice directly rather +// than incur a usually unnecessary copy. +type Value interface { + // Set converts and assigns src to itself. Value takes ownership of src. + Set(src interface{}) error + + // Get returns the simplest representation of Value. Get may return a pointer to an internal value but it must never + // mutate that value. e.g. If Get returns a []byte Value must never change the contents of the []byte. + Get() interface{} + + // AssignTo converts and assigns the Value to dst. AssignTo may a pointer to an internal value but it must never + // mutate that value. e.g. If Get returns a []byte Value must never change the contents of the []byte. + AssignTo(dst interface{}) error +} + +// TypeValue is a Value where instances can represent different PostgreSQL types. This can be useful for +// representing types such as enums, composites, and arrays. +// +// In general, instances of TypeValue should not be used to directly represent a value. It should only be used as an +// encoder and decoder internal to ConnInfo. +type TypeValue interface { + Value + + // NewTypeValue creates a TypeValue including references to internal type information. e.g. the list of members + // in an EnumType. + NewTypeValue() Value + + // TypeName returns the PostgreSQL name of this type. + TypeName() string +} + +// ValueTranscoder is a value that implements the text and binary encoding and decoding interfaces. +type ValueTranscoder interface { + Value + FormatSupport + ParamEncoder + ResultDecoder +} + +type FormatSupport interface { + BinaryFormatSupported() bool + TextFormatSupported() bool + PreferredFormat() int16 +} + +type ParamEncoder interface { + // EncodeParam should append the encoded value of self to buf. If self is the + // SQL value NULL then append nothing and return (nil, nil). The caller of + // EncodeText is responsible for writing the correct NULL value or the + // length of the data written. + EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) +} + +type ResultDecoder interface { + // DecodeResult decodes src into ResultDecoder. If src is nil then the + // original SQL value is NULL. ResultDecoder takes ownership of src. The + // caller MUST not use it again. + DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error +} + +// ResultFormatPreferrer allows a type to specify its preferred result format instead of it being inferred from +// whether it is also a BinaryDecoder. +type ResultFormatPreferrer interface { + PreferredResultFormat() int16 +} + +type BinaryDecoder interface { + // DecodeBinary decodes src into BinaryDecoder. If src is nil then the + // original SQL value is NULL. BinaryDecoder takes ownership of src. The + // caller MUST not use it again. + DecodeBinary(ci *ConnInfo, src []byte) error +} + +type TextDecoder interface { + // DecodeText decodes src into TextDecoder. If src is nil then the original + // SQL value is NULL. TextDecoder takes ownership of src. The caller MUST not + // use it again. + DecodeText(ci *ConnInfo, src []byte) error +} + +// BinaryEncoder is implemented by types that can encode themselves into the +// PostgreSQL binary wire format. +type BinaryEncoder interface { + // EncodeBinary should append the binary format of self to buf. If self is the + // SQL value NULL then append nothing and return (nil, nil). The caller of + // EncodeBinary is responsible for writing the correct NULL value or the + // length of the data written. + EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) +} + +// TextEncoder is implemented by types that can encode themselves into the +// PostgreSQL text wire format. +type TextEncoder interface { + // EncodeText should append the text format of self to buf. If self is the + // SQL value NULL then append nothing and return (nil, nil). The caller of + // EncodeText is responsible for writing the correct NULL value or the + // length of the data written. + EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) +} + +type nullAssignmentError struct { + dst interface{} +} + +func (e *nullAssignmentError) Error() string { + return fmt.Sprintf("cannot assign NULL to %T", e.dst) +} + +type DataType struct { + Value Value + + resultDecoder ResultDecoder + + textDecoder TextDecoder + binaryDecoder BinaryDecoder + + Name string + OID uint32 +} + +type ConnInfo struct { + oidToDataType map[uint32]*DataType + nameToDataType map[string]*DataType + reflectTypeToName map[reflect.Type]string + oidToFormatCode map[uint32]int16 + oidToResultFormatCode map[uint32]int16 + + reflectTypeToDataType map[reflect.Type]*DataType + + preferAssignToOverSQLScannerTypes map[reflect.Type]struct{} +} + +func newConnInfo() *ConnInfo { + return &ConnInfo{ + oidToDataType: make(map[uint32]*DataType), + nameToDataType: make(map[string]*DataType), + reflectTypeToName: make(map[reflect.Type]string), + oidToFormatCode: make(map[uint32]int16), + oidToResultFormatCode: make(map[uint32]int16), + preferAssignToOverSQLScannerTypes: make(map[reflect.Type]struct{}), + } +} + +func NewConnInfo() *ConnInfo { + ci := newConnInfo() + + ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID}) + ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID}) + ci.RegisterDataType(DataType{Value: &BPCharArray{}, Name: "_bpchar", OID: BPCharArrayOID}) + ci.RegisterDataType(DataType{Value: &ByteaArray{}, Name: "_bytea", OID: ByteaArrayOID}) + ci.RegisterDataType(DataType{Value: &CIDRArray{}, Name: "_cidr", OID: CIDRArrayOID}) + ci.RegisterDataType(DataType{Value: &DateArray{}, Name: "_date", OID: DateArrayOID}) + ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID}) + ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID}) + ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID}) + ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID}) + ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID}) + ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID}) + ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) + ci.RegisterDataType(DataType{Value: &TextArray{}, Name: "_text", OID: TextArrayOID}) + ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID}) + ci.RegisterDataType(DataType{Value: &TimestamptzArray{}, Name: "_timestamptz", OID: TimestamptzArrayOID}) + ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) + ci.RegisterDataType(DataType{Value: &VarcharArray{}, Name: "_varchar", OID: VarcharArrayOID}) + ci.RegisterDataType(DataType{Value: &ACLItem{}, Name: "aclitem", OID: ACLItemOID}) + ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID}) + ci.RegisterDataType(DataType{Value: &Bool{}, Name: "bool", OID: BoolOID}) + ci.RegisterDataType(DataType{Value: &Box{}, Name: "box", OID: BoxOID}) + ci.RegisterDataType(DataType{Value: &BPChar{}, Name: "bpchar", OID: BPCharOID}) + ci.RegisterDataType(DataType{Value: &Bytea{}, Name: "bytea", OID: ByteaOID}) + ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) + ci.RegisterDataType(DataType{Value: &CID{}, Name: "cid", OID: CIDOID}) + ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID}) + ci.RegisterDataType(DataType{Value: &Circle{}, Name: "circle", OID: CircleOID}) + ci.RegisterDataType(DataType{Value: &Date{}, Name: "date", OID: DateOID}) + ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) + ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) + ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID}) + ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID}) + ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID}) + ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID}) + ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) + ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID}) + ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) + ci.RegisterDataType(DataType{Value: &Interval{}, Name: "interval", OID: IntervalOID}) + ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID}) + ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) + ci.RegisterDataType(DataType{Value: &JSONBArray{}, Name: "_jsonb", OID: JSONBArrayOID}) + ci.RegisterDataType(DataType{Value: &Line{}, Name: "line", OID: LineOID}) + ci.RegisterDataType(DataType{Value: &Lseg{}, Name: "lseg", OID: LsegOID}) + ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID}) + ci.RegisterDataType(DataType{Value: &Name{}, Name: "name", OID: NameOID}) + ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID}) + ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) + ci.RegisterDataType(DataType{Value: &OIDValue{}, Name: "oid", OID: OIDOID}) + ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID}) + ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID}) + ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID}) + ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) + ci.RegisterDataType(DataType{Value: &Text{}, Name: "text", OID: TextOID}) + ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID}) + ci.RegisterDataType(DataType{Value: &Time{}, Name: "time", OID: TimeOID}) + ci.RegisterDataType(DataType{Value: &Timestamp{}, Name: "timestamp", OID: TimestampOID}) + ci.RegisterDataType(DataType{Value: &Timestamptz{}, Name: "timestamptz", OID: TimestamptzOID}) + ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) + ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) + ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) + ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) + ci.RegisterDataType(DataType{Value: &Unknown{}, Name: "unknown", OID: UnknownOID}) + ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID}) + ci.RegisterDataType(DataType{Value: &Varbit{}, Name: "varbit", OID: VarbitOID}) + ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID}) + ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID}) + + registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) { + ci.RegisterDefaultPgType(value, name) + valueType := reflect.TypeOf(value) + + ci.RegisterDefaultPgType(reflect.New(valueType).Interface(), name) + + sliceType := reflect.SliceOf(valueType) + ci.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName) + + ci.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName) + } + + // Integer types that directly map to a PostgreSQL type + registerDefaultPgTypeVariants("int2", "_int2", int16(0)) + registerDefaultPgTypeVariants("int4", "_int4", int32(0)) + registerDefaultPgTypeVariants("int8", "_int8", int64(0)) + + // Integer types that do not have a direct match to a PostgreSQL type + registerDefaultPgTypeVariants("int8", "_int8", uint16(0)) + registerDefaultPgTypeVariants("int8", "_int8", uint32(0)) + registerDefaultPgTypeVariants("int8", "_int8", uint64(0)) + registerDefaultPgTypeVariants("int8", "_int8", int(0)) + registerDefaultPgTypeVariants("int8", "_int8", uint(0)) + + registerDefaultPgTypeVariants("float4", "_float4", float32(0)) + registerDefaultPgTypeVariants("float8", "_float8", float64(0)) + + registerDefaultPgTypeVariants("bool", "_bool", false) + registerDefaultPgTypeVariants("timestamptz", "_timestamptz", time.Time{}) + registerDefaultPgTypeVariants("text", "_text", "") + registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil)) + + registerDefaultPgTypeVariants("inet", "_inet", net.IP{}) + ci.RegisterDefaultPgType((*net.IPNet)(nil), "cidr") + ci.RegisterDefaultPgType([]*net.IPNet(nil), "_cidr") + + return ci +} + +func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) { + for name, oid := range nameOIDs { + var value Value + if t, ok := nameValues[name]; ok { + value = reflect.New(reflect.ValueOf(t).Elem().Type()).Interface().(Value) + } else { + value = &GenericText{} + } + ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid}) + } +} + +func (ci *ConnInfo) RegisterDataType(t DataType) { + t.Value = NewValue(t.Value) + + ci.oidToDataType[t.OID] = &t + ci.nameToDataType[t.Name] = &t + + { + var formatCode int16 + if pfp, ok := t.Value.(FormatSupport); ok { + formatCode = pfp.PreferredFormat() + } else if _, ok := t.Value.(BinaryEncoder); ok { + formatCode = BinaryFormatCode + } + ci.oidToFormatCode[t.OID] = formatCode + } + + if d, ok := t.Value.(ResultDecoder); ok { + t.resultDecoder = d + } + + if d, ok := t.Value.(TextDecoder); ok { + t.textDecoder = d + } + + if d, ok := t.Value.(BinaryDecoder); ok { + t.binaryDecoder = d + } + + ci.reflectTypeToDataType = nil // Invalidated by type registration +} + +// RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be +// encoded or decoded is determined by the PostgreSQL OID. But if the OID of a value to be encoded or decoded is +// unknown, this additional mapping will be used by DataTypeForValue to determine a suitable data type. +func (ci *ConnInfo) RegisterDefaultPgType(value interface{}, name string) { + ci.reflectTypeToName[reflect.TypeOf(value)] = name + ci.reflectTypeToDataType = nil // Invalidated by registering a default type +} + +func (ci *ConnInfo) DataTypeForOID(oid uint32) (*DataType, bool) { + dt, ok := ci.oidToDataType[oid] + return dt, ok +} + +func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) { + dt, ok := ci.nameToDataType[name] + return dt, ok +} + +func (ci *ConnInfo) buildReflectTypeToDataType() { + ci.reflectTypeToDataType = make(map[reflect.Type]*DataType) + + for _, dt := range ci.oidToDataType { + if _, is := dt.Value.(TypeValue); !is { + ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt + } + } + + for reflectType, name := range ci.reflectTypeToName { + if dt, ok := ci.nameToDataType[name]; ok { + ci.reflectTypeToDataType[reflectType] = dt + } + } +} + +// DataTypeForValue finds a data type suitable for v. Use RegisterDataType to register types that can encode and decode +// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. +func (ci *ConnInfo) DataTypeForValue(v interface{}) (*DataType, bool) { + if ci.reflectTypeToDataType == nil { + ci.buildReflectTypeToDataType() + } + + if tv, ok := v.(TypeValue); ok { + dt, ok := ci.nameToDataType[tv.TypeName()] + return dt, ok + } + + dt, ok := ci.reflectTypeToDataType[reflect.TypeOf(v)] + return dt, ok +} + +func (ci *ConnInfo) FormatCodeForOID(oid uint32) int16 { + fc, ok := ci.oidToFormatCode[oid] + if ok { + return fc + } + return TextFormatCode +} + +// PreferAssignToOverSQLScannerForType makes a sql.Scanner type use the AssignTo scan path instead of sql.Scanner. +// This is primarily for efficient integration with 3rd party numeric and UUID types. +func (ci *ConnInfo) PreferAssignToOverSQLScannerForType(value interface{}) { + ci.preferAssignToOverSQLScannerTypes[reflect.TypeOf(value)] = struct{}{} +} + +// ScanPlan is a precompiled plan to scan into a type of destination. +type ScanPlan interface { + // Scan scans src into dst. If the dst type has changed in an incompatible way a ScanPlan should automatically + // replan and scan. + Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error +} + +type scanPlanDstResultDecoder struct{} + +func (scanPlanDstResultDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if d, ok := (dst).(ResultDecoder); ok { + return d.DecodeResult(ci, oid, formatCode, src) + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanDstBinaryDecoder struct{} + +func (scanPlanDstBinaryDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if d, ok := (dst).(BinaryDecoder); ok { + return d.DecodeBinary(ci, src) + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanDstTextDecoder struct{} + +func (plan scanPlanDstTextDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if d, ok := (dst).(TextDecoder); ok { + return d.DecodeText(ci, src) + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanDataTypeSQLScanner DataType + +func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner, ok := dst.(sql.Scanner) + if !ok { + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + + dt := (*DataType)(plan) + var err error + switch formatCode { + case BinaryFormatCode: + err = dt.binaryDecoder.DecodeBinary(ci, src) + case TextFormatCode: + err = dt.textDecoder.DecodeText(ci, src) + } + if err != nil { + return err + } + + sqlSrc, err := DatabaseSQLValue(ci, dt.Value) + if err != nil { + return err + } + return scanner.Scan(sqlSrc) +} + +type scanPlanDataTypeAssignTo DataType + +func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + dt := (*DataType)(plan) + var err error + if dt.resultDecoder != nil { + err = dt.resultDecoder.DecodeResult(ci, oid, formatCode, src) + } else { + switch formatCode { + case BinaryFormatCode: + err = dt.binaryDecoder.DecodeBinary(ci, src) + case TextFormatCode: + err = dt.textDecoder.DecodeText(ci, src) + } + } + if err != nil { + return err + } + + assignToErr := dt.Value.AssignTo(dst) + if assignToErr == nil { + return nil + } + + if dstPtr, ok := dst.(*interface{}); ok { + *dstPtr = dt.Value.Get() + return nil + } + + // assignToErr might have failed because the type of destination has changed + newPlan := ci.PlanScan(oid, formatCode, dst) + if newPlan, sameType := newPlan.(*scanPlanDataTypeAssignTo); !sameType { + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + + return assignToErr +} + +type scanPlanSQLScanner struct{} + +func (scanPlanSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := dst.(sql.Scanner) + if src == nil { + // This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the + // text format path would be converted to empty string. + return scanner.Scan(nil) + } else if formatCode == BinaryFormatCode { + return scanner.Scan(src) + } else { + return scanner.Scan(string(src)) + } +} + +type scanPlanReflection struct{} + +func (scanPlanReflection) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + // We might be given a pointer to something that implements the decoder interface(s), + // even though the pointer itself doesn't. + refVal := reflect.ValueOf(dst) + if refVal.Kind() == reflect.Ptr && refVal.Type().Elem().Kind() == reflect.Ptr { + // If the database returned NULL, then we set dest as nil to indicate that. + if src == nil { + nilPtr := reflect.Zero(refVal.Type().Elem()) + refVal.Elem().Set(nilPtr) + return nil + } + + // We need to allocate an element, and set the destination to it + // Then we can retry as that element. + elemPtr := reflect.New(refVal.Type().Elem().Elem()) + refVal.Elem().Set(elemPtr) + + plan := ci.PlanScan(oid, formatCode, elemPtr.Interface()) + return plan.Scan(ci, oid, formatCode, src, elemPtr.Interface()) + } + + return scanUnknownType(oid, formatCode, src, dst) +} + +type scanPlanBinaryInt16 struct{} + +func (scanPlanBinaryInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + if p, ok := (dst).(*int16); ok { + *p = int16(binary.BigEndian.Uint16(src)) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanBinaryInt32 struct{} + +func (scanPlanBinaryInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + if p, ok := (dst).(*int32); ok { + *p = int32(binary.BigEndian.Uint32(src)) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanBinaryInt64 struct{} + +func (scanPlanBinaryInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + if p, ok := (dst).(*int64); ok { + *p = int64(binary.BigEndian.Uint64(src)) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanBinaryFloat32 struct{} + +func (scanPlanBinaryFloat32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + if p, ok := (dst).(*float32); ok { + n := int32(binary.BigEndian.Uint32(src)) + *p = float32(math.Float32frombits(uint32(n))) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanBinaryFloat64 struct{} + +func (scanPlanBinaryFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + if p, ok := (dst).(*float64); ok { + n := int64(binary.BigEndian.Uint64(src)) + *p = float64(math.Float64frombits(uint64(n))) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanBinaryBytes struct{} + +func (scanPlanBinaryBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if p, ok := (dst).(*[]byte); ok { + *p = src + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanString struct{} + +func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if p, ok := (dst).(*string); ok { + *p = string(src) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +// PlanScan prepares a plan to scan a value into dst. +func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan { + switch formatCode { + case BinaryFormatCode: + switch dst.(type) { + case *string: + switch oid { + case TextOID, VarcharOID: + return scanPlanString{} + } + case *int16: + if oid == Int2OID { + return scanPlanBinaryInt16{} + } + case *int32: + if oid == Int4OID { + return scanPlanBinaryInt32{} + } + case *int64: + if oid == Int8OID { + return scanPlanBinaryInt64{} + } + case *float32: + if oid == Float4OID { + return scanPlanBinaryFloat32{} + } + case *float64: + if oid == Float8OID { + return scanPlanBinaryFloat64{} + } + case *[]byte: + switch oid { + case ByteaOID, TextOID, VarcharOID, JSONOID: + return scanPlanBinaryBytes{} + } + case BinaryDecoder: + return scanPlanDstBinaryDecoder{} + } + case TextFormatCode: + switch dst.(type) { + case *string: + return scanPlanString{} + case *[]byte: + if oid != ByteaOID { + return scanPlanBinaryBytes{} + } + case TextDecoder: + return scanPlanDstTextDecoder{} + } + } + + var dt *DataType + + if oid == 0 { + if dataType, ok := ci.DataTypeForValue(dst); ok { + dt = dataType + } + } else { + if dataType, ok := ci.DataTypeForOID(oid); ok { + dt = dataType + } + } + + if dt != nil { + if _, ok := dst.(sql.Scanner); ok { + if _, found := ci.preferAssignToOverSQLScannerTypes[reflect.TypeOf(dst)]; !found { + return (*scanPlanDataTypeSQLScanner)(dt) + } + } + return (*scanPlanDataTypeAssignTo)(dt) + } + + if _, ok := dst.(sql.Scanner); ok { + return scanPlanSQLScanner{} + } + + return scanPlanReflection{} +} + +func (ci *ConnInfo) Scan(oid uint32, formatCode int16, src []byte, dst interface{}) error { + if dst == nil { + return nil + } + + plan := ci.PlanScan(oid, formatCode, dst) + return plan.Scan(ci, oid, formatCode, src, dst) +} + +func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) error { + switch dest := dest.(type) { + case *string: + if formatCode == BinaryFormatCode { + return fmt.Errorf("unknown oid %d in binary format cannot be scanned into %T", oid, dest) + } + *dest = string(buf) + return nil + case *[]byte: + *dest = buf + return nil + default: + if nextDst, retry := GetAssignToDstType(dest); retry { + return scanUnknownType(oid, formatCode, buf, nextDst) + } + return fmt.Errorf("unknown oid %d cannot be scanned into %T", oid, dest) + } +} + +// NewValue returns a new instance of the same type as v. +func NewValue(v Value) Value { + if tv, ok := v.(TypeValue); ok { + return tv.NewTypeValue() + } else { + return reflect.New(reflect.ValueOf(v).Elem().Type()).Interface().(Value) + } +} + +var nameValues map[string]Value + +func init() { + nameValues = map[string]Value{ + "_aclitem": &ACLItemArray{}, + "_bool": &BoolArray{}, + "_bpchar": &BPCharArray{}, + "_bytea": &ByteaArray{}, + "_cidr": &CIDRArray{}, + "_date": &DateArray{}, + "_float4": &Float4Array{}, + "_float8": &Float8Array{}, + "_inet": &InetArray{}, + "_int2": &Int2Array{}, + "_int4": &Int4Array{}, + "_int8": &Int8Array{}, + "_numeric": &NumericArray{}, + "_text": &TextArray{}, + "_timestamp": &TimestampArray{}, + "_timestamptz": &TimestamptzArray{}, + "_uuid": &UUIDArray{}, + "_varchar": &VarcharArray{}, + "_jsonb": &JSONBArray{}, + "aclitem": &ACLItem{}, + "bit": &Bit{}, + "bool": &Bool{}, + "box": &Box{}, + "bpchar": &BPChar{}, + "bytea": &Bytea{}, + "char": &QChar{}, + "cid": &CID{}, + "cidr": &CIDR{}, + "circle": &Circle{}, + "date": &Date{}, + "daterange": &Daterange{}, + "float4": &Float4{}, + "float8": &Float8{}, + "hstore": &Hstore{}, + "inet": &Inet{}, + "int2": &Int2{}, + "int4": &Int4{}, + "int4range": &Int4range{}, + "int8": &Int8{}, + "int8range": &Int8range{}, + "interval": &Interval{}, + "json": &JSON{}, + "jsonb": &JSONB{}, + "line": &Line{}, + "lseg": &Lseg{}, + "macaddr": &Macaddr{}, + "name": &Name{}, + "numeric": &Numeric{}, + "numrange": &Numrange{}, + "oid": &OIDValue{}, + "path": &Path{}, + "point": &Point{}, + "polygon": &Polygon{}, + "record": &Record{}, + "text": &Text{}, + "tid": &TID{}, + "timestamp": &Timestamp{}, + "timestamptz": &Timestamptz{}, + "tsrange": &Tsrange{}, + "_tsrange": &TsrangeArray{}, + "tstzrange": &Tstzrange{}, + "_tstzrange": &TstzrangeArray{}, + "unknown": &Unknown{}, + "uuid": &UUID{}, + "varbit": &Varbit{}, + "varchar": &Varchar{}, + "xid": &XID{}, + } +} diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go new file mode 100644 index 00000000..9bf1f242 --- /dev/null +++ b/pgtype/pgtype_test.go @@ -0,0 +1,301 @@ +package pgtype_test + +import ( + "bytes" + "database/sql" + "errors" + "net" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" + _ "github.com/jackc/pgx/v4/stdlib" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test for renamed types +type _string string +type _bool bool +type _int8 int8 +type _int16 int16 +type _int16Slice []int16 +type _int32Slice []int32 +type _int64Slice []int64 +type _float32Slice []float32 +type _float64Slice []float64 +type _byteSlice []byte + +func mustParseCIDR(t testing.TB, s string) *net.IPNet { + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + t.Fatal(err) + } + + return ipnet +} + +func mustParseInet(t testing.TB, s string) *net.IPNet { + ip, ipnet, err := net.ParseCIDR(s) + if err == nil { + if ipv4 := ip.To4(); ipv4 != nil { + ipnet.IP = ipv4 + } + return ipnet + } + + // May be bare IP address. + // + ip = net.ParseIP(s) + if ip == nil { + t.Fatal(errors.New("unable to parse inet address")) + } + ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} + if ipv4 := ip.To4(); ipv4 != nil { + ipnet.IP = ipv4 + ipnet.Mask = net.CIDRMask(32, 32) + } + return ipnet +} + +func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { + addr, err := net.ParseMAC(s) + if err != nil { + t.Fatal(err) + } + + return addr +} + +func TestConnInfoFormatCodeForOID(t *testing.T) { + ci := pgtype.NewConnInfo() + + // pgtype.JSONB implements BinaryEncoder but also implements ParamFormatPreferrer to override it to text. + assert.Equal(t, int16(pgtype.TextFormatCode), ci.FormatCodeForOID(pgtype.JSONBOID)) + + // pgtype.Int4 implements BinaryEncoder but does not implement ParamFormatPreferrer so it should be binary. + assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.FormatCodeForOID(pgtype.Int4OID)) +} + +func TestConnInfoScanNilIsNoOp(t *testing.T) { + ci := pgtype.NewConnInfo() + + err := ci.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), nil) + assert.NoError(t, err) +} + +func TestConnInfoScanTextFormatInterfacePtr(t *testing.T) { + ci := pgtype.NewConnInfo() + var got interface{} + err := ci.Scan(pgtype.TextOID, pgx.TextFormatCode, []byte("foo"), &got) + require.NoError(t, err) + assert.Equal(t, "foo", got) +} + +func TestConnInfoScanTextFormatNonByteaIntoByteSlice(t *testing.T) { + ci := pgtype.NewConnInfo() + var got []byte + err := ci.Scan(pgtype.JSONBOID, pgx.TextFormatCode, []byte("{}"), &got) + require.NoError(t, err) + assert.Equal(t, []byte("{}"), got) +} + +func TestConnInfoScanBinaryFormatInterfacePtr(t *testing.T) { + ci := pgtype.NewConnInfo() + var got interface{} + err := ci.Scan(pgtype.TextOID, pgx.BinaryFormatCode, []byte("foo"), &got) + require.NoError(t, err) + assert.Equal(t, "foo", got) +} + +func TestConnInfoScanUnknownOIDToStringsAndBytes(t *testing.T) { + unknownOID := uint32(999999) + srcBuf := []byte("foo") + ci := pgtype.NewConnInfo() + + var s string + err := ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &s) + assert.NoError(t, err) + assert.Equal(t, "foo", s) + + var rs _string + err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rs) + assert.NoError(t, err) + assert.Equal(t, "foo", string(rs)) + + var b []byte + err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &b) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), b) + + err = ci.Scan(unknownOID, pgx.BinaryFormatCode, srcBuf, &b) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), b) + + var rb _byteSlice + err = ci.Scan(unknownOID, pgx.TextFormatCode, srcBuf, &rb) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), []byte(rb)) + + err = ci.Scan(unknownOID, pgx.BinaryFormatCode, srcBuf, &b) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), []byte(rb)) +} + +type pgCustomType struct { + a string + b string +} + +func (ct *pgCustomType) DecodeText(ci *pgtype.ConnInfo, buf []byte) error { + // This is not a complete parser for the text format of composite types. This is just for test purposes. + if buf == nil { + return errors.New("cannot parse null") + } + + if len(buf) < 2 { + return errors.New("invalid text format") + } + + parts := bytes.Split(buf[1:len(buf)-1], []byte(",")) + if len(parts) != 2 { + return errors.New("wrong number of parts") + } + + ct.a = string(parts[0]) + ct.b = string(parts[1]) + + return nil +} + +func TestConnInfoScanUnregisteredOIDToCustomType(t *testing.T) { + unregisteredOID := uint32(999999) + ci := pgtype.NewConnInfo() + + var ct pgCustomType + err := ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &ct) + assert.NoError(t, err) + assert.Equal(t, "foo", ct.a) + assert.Equal(t, "bar", ct.b) + + // Scan value into pointer to custom type + var pCt *pgCustomType + err = ci.Scan(unregisteredOID, pgx.TextFormatCode, []byte("(foo,bar)"), &pCt) + assert.NoError(t, err) + require.NotNil(t, pCt) + assert.Equal(t, "foo", pCt.a) + assert.Equal(t, "bar", pCt.b) + + // Scan null into pointer to custom type + err = ci.Scan(unregisteredOID, pgx.TextFormatCode, nil, &pCt) + assert.NoError(t, err) + assert.Nil(t, pCt) +} + +func TestConnInfoScanUnknownOIDTextFormat(t *testing.T) { + ci := pgtype.NewConnInfo() + + var n int32 + err := ci.Scan(0, pgx.TextFormatCode, []byte("123"), &n) + assert.NoError(t, err) + assert.EqualValues(t, 123, n) +} + +func TestConnInfoScanUnknownOIDIntoSQLScanner(t *testing.T) { + ci := pgtype.NewConnInfo() + + var s sql.NullString + err := ci.Scan(0, pgx.TextFormatCode, []byte(nil), &s) + assert.NoError(t, err) + assert.Equal(t, "", s.String) + assert.False(t, s.Valid) +} + +func BenchmarkConnInfoScanInt4IntoBinaryDecoder(b *testing.B) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v pgtype.Int4 + + for i := 0; i < b.N; i++ { + v = pgtype.Int4{} + err := ci.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != (pgtype.Int4{Int: 42, Valid: true}) { + b.Fatal("scan failed due to bad value") + } + } +} + +func TestScanPlanBinaryInt32ScanChangedType(t *testing.T) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v int32 + + plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) + err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + require.NoError(t, err) + require.EqualValues(t, 42, v) + + var d pgtype.Int4 + err = plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &d) + require.NoError(t, err) + require.EqualValues(t, 42, d.Int) + require.True(t, d.Valid) +} + +func BenchmarkConnInfoScanInt4IntoGoInt32(b *testing.B) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v int32 + + for i := 0; i < b.N; i++ { + v = 0 + err := ci.Scan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != 42 { + b.Fatal("scan failed due to bad value") + } + } +} + +func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v pgtype.Int4 + + plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) + + for i := 0; i < b.N; i++ { + v = pgtype.Int4{} + err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != (pgtype.Int4{Int: 42, Valid: true}) { + b.Fatal("scan failed due to bad value") + } + } +} + +func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v int32 + + plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, &v) + + for i := 0; i < b.N; i++ { + v = 0 + err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != 42 { + b.Fatal("scan failed due to bad value") + } + } +} diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go new file mode 100644 index 00000000..e36ebb1f --- /dev/null +++ b/pgtype/pguint32.go @@ -0,0 +1,148 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +// pguint32 is the core type that is used to implement PostgreSQL types such as +// CID and XID. +type pguint32 struct { + Uint uint32 + Valid bool +} + +// Set converts from src to dst. Note that as pguint32 is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *pguint32) Set(src interface{}) error { + switch value := src.(type) { + case int64: + if value < 0 { + return fmt.Errorf("%d is less than minimum value for pguint32", value) + } + if value > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for pguint32", value) + } + *dst = pguint32{Uint: uint32(value), Valid: true} + case uint32: + *dst = pguint32{Uint: value, Valid: true} + default: + return fmt.Errorf("cannot convert %v to pguint32", value) + } + + return nil +} + +func (dst pguint32) Get() interface{} { + if !dst.Valid { + return nil + } + return dst.Uint +} + +// AssignTo assigns from src to dst. Note that as pguint32 is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *pguint32) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *uint32: + if src.Valid { + *v = src.Uint + } else { + return fmt.Errorf("cannot assign %v into %T", src, dst) + } + case **uint32: + if src.Valid { + n := src.Uint + *v = &n + } else { + *v = nil + } + } + + return nil +} + +func (dst *pguint32) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = pguint32{} + return nil + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + *dst = pguint32{Uint: uint32(n), Valid: true} + return nil +} + +func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = pguint32{} + return nil + } + + if len(src) != 4 { + return fmt.Errorf("invalid length: %v", len(src)) + } + + n := binary.BigEndian.Uint32(src) + *dst = pguint32{Uint: n, Valid: true} + return nil +} + +func (src pguint32) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return append(buf, strconv.FormatUint(uint64(src.Uint), 10)...), nil +} + +func (src pguint32) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return pgio.AppendUint32(buf, src.Uint), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *pguint32) Scan(src interface{}) error { + if src == nil { + *dst = pguint32{} + return nil + } + + switch src := src.(type) { + case uint32: + *dst = pguint32{Uint: src, Valid: true} + return nil + case int64: + *dst = pguint32{Uint: uint32(src), Valid: true} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src pguint32) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return int64(src.Uint), nil +} diff --git a/pgtype/pgxtype/README.md b/pgtype/pgxtype/README.md new file mode 100644 index 00000000..a070111f --- /dev/null +++ b/pgtype/pgxtype/README.md @@ -0,0 +1,3 @@ +# pgxtype + +pgxtype is a helper module that connects pgx and pgtype. This package is not currently covered by semantic version guarantees. i.e. The interfaces may change without a major version release of pgtype. diff --git a/pgtype/pgxtype/pgxtype.go b/pgtype/pgxtype/pgxtype.go new file mode 100644 index 00000000..041f2545 --- /dev/null +++ b/pgtype/pgxtype/pgxtype.go @@ -0,0 +1,145 @@ +package pgxtype + +import ( + "context" + "errors" + + "github.com/jackc/pgconn" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" +) + +type Querier interface { + Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) + Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row +} + +// LoadDataType uses conn to inspect the database for typeName and produces a pgtype.DataType suitable for +// registration on ci. +func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeName string) (pgtype.DataType, error) { + var oid uint32 + + err := conn.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid) + if err != nil { + return pgtype.DataType{}, err + } + + var typtype string + + err = conn.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype) + if err != nil { + return pgtype.DataType{}, err + } + + switch typtype { + case "b": // array + elementOID, err := GetArrayElementOID(ctx, conn, oid) + if err != nil { + return pgtype.DataType{}, err + } + + var element pgtype.ValueTranscoder + if dt, ok := ci.DataTypeForOID(elementOID); ok { + if element, ok = dt.Value.(pgtype.ValueTranscoder); !ok { + return pgtype.DataType{}, errors.New("array element OID not registered as ValueTranscoder") + } + } + + newElement := func() pgtype.ValueTranscoder { + return pgtype.NewValue(element).(pgtype.ValueTranscoder) + } + + at := pgtype.NewArrayType(typeName, elementOID, newElement) + return pgtype.DataType{Value: at, Name: typeName, OID: oid}, nil + case "c": // composite + fields, err := GetCompositeFields(ctx, conn, oid) + if err != nil { + return pgtype.DataType{}, err + } + ct, err := pgtype.NewCompositeType(typeName, fields, ci) + if err != nil { + return pgtype.DataType{}, err + } + return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil + case "e": // enum + members, err := GetEnumMembers(ctx, conn, oid) + if err != nil { + return pgtype.DataType{}, err + } + return pgtype.DataType{Value: pgtype.NewEnumType(typeName, members), Name: typeName, OID: oid}, nil + default: + return pgtype.DataType{}, errors.New("unknown typtype") + } +} + +func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, error) { + var typelem uint32 + + err := conn.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem) + if err != nil { + return 0, err + } + + return typelem, nil +} + +// GetCompositeFields gets the fields of a composite type. +func GetCompositeFields(ctx context.Context, conn Querier, oid uint32) ([]pgtype.CompositeTypeField, error) { + var typrelid uint32 + + err := conn.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid) + if err != nil { + return nil, err + } + + var fields []pgtype.CompositeTypeField + + rows, err := conn.Query(ctx, `select attname, atttypid +from pg_attribute +where attrelid=$1 +order by attnum`, typrelid) + if err != nil { + return nil, err + } + + for rows.Next() { + var f pgtype.CompositeTypeField + err := rows.Scan(&f.Name, &f.OID) + if err != nil { + return nil, err + } + fields = append(fields, f) + } + + if rows.Err() != nil { + return nil, rows.Err() + } + + return fields, nil +} + +// GetEnumMembers gets the possible values of the enum by oid. +func GetEnumMembers(ctx context.Context, conn Querier, oid uint32) ([]string, error) { + members := []string{} + + rows, err := conn.Query(ctx, "select enumlabel from pg_enum where enumtypid=$1 order by enumsortorder", oid) + if err != nil { + return nil, err + } + + for rows.Next() { + var m string + err := rows.Scan(&m) + if err != nil { + return nil, err + } + members = append(members, m) + } + + if rows.Err() != nil { + return nil, rows.Err() + } + + return members, nil +} diff --git a/pgtype/point.go b/pgtype/point.go new file mode 100644 index 00000000..d35dbf03 --- /dev/null +++ b/pgtype/point.go @@ -0,0 +1,200 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +type Vec2 struct { + X float64 + Y float64 +} + +type Point struct { + P Vec2 + Valid bool +} + +func (dst *Point) Set(src interface{}) error { + if src == nil { + dst.Valid = false + return nil + } + err := fmt.Errorf("cannot convert %v to Point", src) + var p *Point + switch value := src.(type) { + case string: + p, err = parsePoint([]byte(value)) + case []byte: + p, err = parsePoint(value) + default: + return err + } + if err != nil { + return err + } + *dst = *p + return nil +} + +func parsePoint(src []byte) (*Point, error) { + if src == nil || bytes.Compare(src, []byte("null")) == 0 { + return &Point{}, nil + } + + if len(src) < 5 { + return nil, fmt.Errorf("invalid length for point: %v", len(src)) + } + if src[0] == '"' && src[len(src)-1] == '"' { + src = src[1 : len(src)-1] + } + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return nil, fmt.Errorf("invalid format for point") + } + + x, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return nil, err + } + + y, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return nil, err + } + + return &Point{P: Vec2{x, y}, Valid: true}, nil +} + +func (dst Point) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *Point) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Point{} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return fmt.Errorf("invalid format for point") + } + + x, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } + + y, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err + } + + *dst = Point{P: Vec2{x, y}, Valid: true} + return nil +} + +func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Point{} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + + *dst = Point{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + Valid: true, + } + return nil +} + +func (src Point) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(src.P.X, 'f', -1, 64), + strconv.FormatFloat(src.P.Y, 'f', -1, 64), + )...), nil +} + +func (src Point) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Point) Scan(src interface{}) error { + if src == nil { + *dst = Point{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Point) Value() (driver.Value, error) { + return EncodeValueText(src) +} + +func (src Point) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + var buff bytes.Buffer + buff.WriteByte('"') + buff.WriteString(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y)) + buff.WriteByte('"') + return buff.Bytes(), nil +} + +func (dst *Point) UnmarshalJSON(point []byte) error { + p, err := parsePoint(point) + if err != nil { + return err + } + *dst = *p + return nil +} diff --git a/pgtype/point_test.go b/pgtype/point_test.go new file mode 100644 index 00000000..82f58e17 --- /dev/null +++ b/pgtype/point_test.go @@ -0,0 +1,135 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestPointTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "point", []interface{}{ + &pgtype.Point{P: pgtype.Vec2{1.234, 5.6789012345}, Valid: true}, + &pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Valid: true}, + &pgtype.Point{}, + }) +} + +func TestPoint_Set(t *testing.T) { + tests := []struct { + name string + arg interface{} + valid bool + wantErr bool + }{ + { + name: "first", + arg: "(12312.123123,123123.123123)", + valid: true, + wantErr: false, + }, + { + name: "second", + arg: "(1231s2.123123,123123.123123)", + valid: false, + wantErr: true, + }, + { + name: "third", + arg: []byte("(122.123123,123.123123)"), + valid: true, + wantErr: false, + }, + { + name: "third", + arg: nil, + valid: false, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dst := &pgtype.Point{} + if err := dst.Set(tt.arg); (err != nil) != tt.wantErr { + t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr) + } + if dst.Valid != tt.valid { + t.Errorf("Expected status: %v; got: %v", tt.valid, dst.Valid) + } + }) + } +} + +func TestPoint_MarshalJSON(t *testing.T) { + tests := []struct { + name string + point pgtype.Point + want []byte + }{ + { + name: "second", + point: pgtype.Point{ + P: pgtype.Vec2{X: 12.245, Y: 432.12}, + Valid: true, + }, + want: []byte(`"(12.245,432.12)"`), + }, + { + name: "third", + point: pgtype.Point{ + P: pgtype.Vec2{}, + }, + want: []byte("null"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.point.MarshalJSON() + require.NoError(t, err) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPoint_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + valid bool + arg []byte + wantErr bool + }{ + { + name: "first", + valid: true, + arg: []byte(`"(123.123,54.12)"`), + wantErr: false, + }, + { + name: "second", + valid: false, + arg: []byte(`"(123.123,54.1sad2)"`), + wantErr: true, + }, + { + name: "third", + valid: false, + arg: []byte("null"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dst := &pgtype.Point{} + if err := dst.UnmarshalJSON(tt.arg); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if dst.Valid != tt.valid { + t.Errorf("Valid mismatch: %v != %v", dst.Valid, tt.valid) + } + }) + } +} diff --git a/pgtype/polygon.go b/pgtype/polygon.go new file mode 100644 index 00000000..956920e6 --- /dev/null +++ b/pgtype/polygon.go @@ -0,0 +1,215 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +type Polygon struct { + P []Vec2 + Valid bool +} + +// Set converts src to dest. +// +// src can be nil, string, []float64, and []pgtype.Vec2. +// +// If src is string the format must be ((x1,y1),(x2,y2),...,(xn,yn)). +// Important that there are no spaces in it. +func (dst *Polygon) Set(src interface{}) error { + if src == nil { + dst.Valid = false + return nil + } + err := fmt.Errorf("cannot convert %v to Polygon", src) + var p *Polygon + switch value := src.(type) { + case string: + p, err = stringToPolygon(value) + case []Vec2: + p = &Polygon{Valid: true, P: value} + err = nil + case []float64: + p, err = float64ToPolygon(value) + default: + return err + } + if err != nil { + return err + } + *dst = *p + return nil +} + +func stringToPolygon(src string) (*Polygon, error) { + p := &Polygon{} + err := p.DecodeText(nil, []byte(src)) + return p, err +} + +func float64ToPolygon(src []float64) (*Polygon, error) { + p := &Polygon{} + if len(src) == 0 { + return p, nil + } + if len(src)%2 != 0 { + return p, fmt.Errorf("invalid length for polygon: %v", len(src)) + } + p.Valid = true + p.P = make([]Vec2, 0) + for i := 0; i < len(src); i += 2 { + p.P = append(p.P, Vec2{X: src[i], Y: src[i+1]}) + } + return p, nil +} + +func (dst Polygon) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *Polygon) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Polygon{} + return nil + } + + if len(src) < 7 { + return fmt.Errorf("invalid length for Polygon: %v", len(src)) + } + + points := make([]Vec2, 0) + + str := string(src[2:]) + + for { + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + points = append(points, Vec2{x, y}) + + if end+3 < len(str) { + str = str[end+3:] + } else { + break + } + } + + *dst = Polygon{P: points, Valid: true} + return nil +} + +func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Polygon{} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for Polygon: %v", len(src)) + } + + pointCount := int(binary.BigEndian.Uint32(src)) + rp := 4 + + if 4+pointCount*16 != len(src) { + return fmt.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) + } + + points := make([]Vec2, pointCount) + for i := 0; i < len(points); i++ { + x := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + y := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} + } + + *dst = Polygon{ + P: points, + Valid: true, + } + return nil +} + +func (src Polygon) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = append(buf, '(') + + for i, p := range src.P { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(p.X, 'f', -1, 64), + strconv.FormatFloat(p.Y, 'f', -1, 64), + )...) + } + + return append(buf, ')'), nil +} + +func (src Polygon) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = pgio.AppendInt32(buf, int32(len(src.P))) + + for _, p := range src.P { + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Polygon) Scan(src interface{}) error { + if src == nil { + *dst = Polygon{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Polygon) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/polygon_test.go b/pgtype/polygon_test.go new file mode 100644 index 00000000..34f8d59a --- /dev/null +++ b/pgtype/polygon_test.go @@ -0,0 +1,89 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestPolygonTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "polygon", []interface{}{ + &pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, + Valid: true, + }, + &pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, + Valid: true, + }, + &pgtype.Polygon{}, + }) +} + +func TestPolygon_Set(t *testing.T) { + tests := []struct { + name string + arg interface{} + valid bool + wantErr bool + }{ + { + name: "string", + arg: "((3.14,1.678901234),(7.1,5.234),(5.0,3.234))", + valid: true, + wantErr: false, + }, { + name: "[]float64", + arg: []float64{1, 2, 3.45, 6.78, 1.23, 4.567, 8.9, 1.0}, + valid: true, + wantErr: false, + }, { + name: "[]Vec2", + arg: []pgtype.Vec2{{1, 2}, {2.3, 4.5}, {6.78, 9.123}}, + valid: true, + wantErr: false, + }, { + name: "null", + arg: nil, + valid: false, + wantErr: false, + }, { + name: "invalid_string_1", + arg: "((3.14,1.678901234),(7.1,5.234),(5.0,3.234x))", + valid: false, + wantErr: true, + }, { + name: "invalid_string_2", + arg: "(3,4)", + valid: false, + wantErr: true, + }, { + name: "invalid_[]float64", + arg: []float64{1, 2, 3.45, 6.78, 1.23, 4.567, 8.9}, + valid: false, + wantErr: true, + }, { + name: "invalid_type", + arg: []int{1, 2, 3, 6}, + valid: false, + wantErr: true, + }, { + name: "empty_[]float64", + arg: []float64{}, + valid: false, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dst := &pgtype.Polygon{} + if err := dst.Set(tt.arg); (err != nil) != tt.wantErr { + t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr) + } + if dst.Valid != tt.valid { + t.Errorf("Expected valid: %v; got: %v", tt.valid, dst.Valid) + } + }) + } +} diff --git a/pgtype/qchar.go b/pgtype/qchar.go new file mode 100644 index 00000000..e56bf142 --- /dev/null +++ b/pgtype/qchar.go @@ -0,0 +1,145 @@ +package pgtype + +import ( + "fmt" + "math" + "strconv" +) + +// QChar is for PostgreSQL's special 8-bit-only "char" type more akin to the C +// language's char type, or Go's byte type. (Note that the name in PostgreSQL +// itself is "char", in double-quotes, and not char.) It gets used a lot in +// PostgreSQL's system tables to hold a single ASCII character value (eg +// pg_class.relkind). It is named Qchar for quoted char to disambiguate from SQL +// standard type char. +// +// Not all possible values of QChar are representable in the text format. +// Therefore, QChar does not implement TextEncoder and TextDecoder. In +// addition, database/sql Scanner and database/sql/driver Value are not +// implemented. +type QChar struct { + Int int8 + Valid bool +} + +func (dst *QChar) Set(src interface{}) error { + if src == nil { + *dst = QChar{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case int8: + *dst = QChar{Int: value, Valid: true} + case uint8: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Valid: true} + case int16: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Valid: true} + case uint16: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Valid: true} + case int32: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Valid: true} + case uint32: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Valid: true} + case int64: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Valid: true} + case uint64: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Valid: true} + case int: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Valid: true} + case uint: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Valid: true} + case string: + num, err := strconv.ParseInt(value, 10, 8) + if err != nil { + return err + } + *dst = QChar{Int: int8(num), Valid: true} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to QChar", value) + } + + return nil +} + +func (dst QChar) Get() interface{} { + if !dst.Valid { + return nil + } + return dst.Int +} + +func (src *QChar) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Valid, dst) +} + +func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = QChar{} + return nil + } + + if len(src) != 1 { + return fmt.Errorf(`invalid length for "char": %v`, len(src)) + } + + *dst = QChar{Int: int8(src[0]), Valid: true} + return nil +} + +func (src QChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return append(buf, byte(src.Int)), nil +} diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go new file mode 100644 index 00000000..eb54bf65 --- /dev/null +++ b/pgtype/qchar_test.go @@ -0,0 +1,143 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestQCharTranscode(t *testing.T) { + testutil.TestPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ + &pgtype.QChar{Int: math.MinInt8, Valid: true}, + &pgtype.QChar{Int: -1, Valid: true}, + &pgtype.QChar{Int: 0, Valid: true}, + &pgtype.QChar{Int: 1, Valid: true}, + &pgtype.QChar{Int: math.MaxInt8, Valid: true}, + &pgtype.QChar{Int: 0}, + }, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func TestQCharSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.QChar + }{ + {source: int8(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: int16(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: int32(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: int64(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: int8(-1), result: pgtype.QChar{Int: -1, Valid: true}}, + {source: int16(-1), result: pgtype.QChar{Int: -1, Valid: true}}, + {source: int32(-1), result: pgtype.QChar{Int: -1, Valid: true}}, + {source: int64(-1), result: pgtype.QChar{Int: -1, Valid: true}}, + {source: uint8(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: uint16(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: uint32(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: uint64(1), result: pgtype.QChar{Int: 1, Valid: true}}, + {source: "1", result: pgtype.QChar{Int: 1, Valid: true}}, + {source: _int8(1), result: pgtype.QChar{Int: 1, Valid: true}}, + } + + for i, tt := range successfulTests { + var r pgtype.QChar + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestQCharAssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.QChar + dst interface{} + expected interface{} + }{ + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i8, expected: int8(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i16, expected: int16(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i32, expected: int32(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i64, expected: int64(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &i, expected: int(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &ui, expected: uint(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.QChar{Int: 0}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.QChar{Int: 0}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.QChar + dst interface{} + expected interface{} + }{ + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &pi8, expected: int8(42)}, + {src: pgtype.QChar{Int: 42, Valid: true}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.QChar + dst interface{} + }{ + {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui8}, + {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui16}, + {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui32}, + {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui64}, + {src: pgtype.QChar{Int: -1, Valid: true}, dst: &ui}, + {src: pgtype.QChar{Int: 0}, dst: &i16}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/range.go b/pgtype/range.go new file mode 100644 index 00000000..e999f6a9 --- /dev/null +++ b/pgtype/range.go @@ -0,0 +1,277 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +type BoundType byte + +const ( + Inclusive = BoundType('i') + Exclusive = BoundType('e') + Unbounded = BoundType('U') + Empty = BoundType('E') +) + +func (bt BoundType) String() string { + return string(bt) +} + +type UntypedTextRange struct { + Lower string + Upper string + LowerType BoundType + UpperType BoundType +} + +func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { + utr := &UntypedTextRange{} + if src == "empty" { + utr.LowerType = Empty + utr.UpperType = Empty + return utr, nil + } + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid lower bound: %v", err) + } + switch r { + case '(': + utr.LowerType = Exclusive + case '[': + utr.LowerType = Inclusive + default: + return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r)) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid lower value: %v", err) + } + buf.UnreadRune() + + if r == ',' { + utr.LowerType = Unbounded + } else { + utr.Lower, err = rangeParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid lower value: %v", err) + } + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("missing range separator: %v", err) + } + if r != ',' { + return nil, fmt.Errorf("missing range separator: %v", r) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid upper value: %v", err) + } + + if r == ')' || r == ']' { + utr.UpperType = Unbounded + } else { + buf.UnreadRune() + utr.Upper, err = rangeParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid upper value: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("missing upper bound: %v", err) + } + switch r { + case ')': + utr.UpperType = Exclusive + case ']': + utr.UpperType = Inclusive + default: + return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r)) + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + return utr, nil +} + +func rangeParseValue(buf *bytes.Buffer) (string, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + if r == '"' { + return rangeParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case ',', '[', ']', '(', ')': + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + +func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + if r != '"' { + buf.UnreadRune() + return s.String(), nil + } + } + s.WriteRune(r) + } +} + +type UntypedBinaryRange struct { + Lower []byte + Upper []byte + LowerType BoundType + UpperType BoundType +} + +// 0 = () = 00000 +// 1 = empty = 00001 +// 2 = [) = 00010 +// 4 = (] = 00100 +// 6 = [] = 00110 +// 8 = ) = 01000 +// 12 = ] = 01100 +// 16 = ( = 10000 +// 18 = [ = 10010 +// 24 = = 11000 + +const emptyMask = 1 +const lowerInclusiveMask = 2 +const upperInclusiveMask = 4 +const lowerUnboundedMask = 8 +const upperUnboundedMask = 16 + +func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { + ubr := &UntypedBinaryRange{} + + if len(src) == 0 { + return nil, fmt.Errorf("range too short: %v", len(src)) + } + + rangeType := src[0] + rp := 1 + + if rangeType&emptyMask > 0 { + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) + } + ubr.LowerType = Empty + ubr.UpperType = Empty + return ubr, nil + } + + if rangeType&lowerInclusiveMask > 0 { + ubr.LowerType = Inclusive + } else if rangeType&lowerUnboundedMask > 0 { + ubr.LowerType = Unbounded + } else { + ubr.LowerType = Exclusive + } + + if rangeType&upperInclusiveMask > 0 { + ubr.UpperType = Inclusive + } else if rangeType&upperUnboundedMask > 0 { + ubr.UpperType = Unbounded + } else { + ubr.UpperType = Exclusive + } + + if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded { + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) + } + return ubr, nil + } + + if len(src[rp:]) < 4 { + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + val := src[rp : rp+valueLen] + rp += valueLen + + if ubr.LowerType != Unbounded { + ubr.Lower = val + } else { + ubr.Upper = val + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + return ubr, nil + } + + if ubr.UpperType != Unbounded { + if len(src[rp:]) < 4 { + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + ubr.Upper = src[rp : rp+valueLen] + rp += valueLen + } + + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + + return ubr, nil + +} diff --git a/pgtype/range_test.go b/pgtype/range_test.go new file mode 100644 index 00000000..9e16df59 --- /dev/null +++ b/pgtype/range_test.go @@ -0,0 +1,177 @@ +package pgtype + +import ( + "bytes" + "testing" +) + +func TestParseUntypedTextRange(t *testing.T) { + tests := []struct { + src string + result UntypedTextRange + err error + }{ + { + src: `[1,2)`, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[1,2]`, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: `(1,3)`, + result: UntypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: ` [1,2) `, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[ foo , bar )`, + result: UntypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["foo","bar")`, + result: UntypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["","bar")`, + result: UntypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[f\"oo\,,b\\ar\))`, + result: UntypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `empty`, + result: UntypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := ParseUntypedTextRange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if r.Lower != tt.result.Lower { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if r.Upper != tt.result.Upper { + t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} + +func TestParseUntypedBinaryRange(t *testing.T) { + tests := []struct { + src []byte + result UntypedBinaryRange + err error + }{ + { + src: []byte{0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{1}, + result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty}, + err: nil, + }, + { + src: []byte{2, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{4, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{6, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{8, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{12, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{16, 0, 0, 0, 2, 0, 4}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{18, 0, 0, 0, 2, 0, 4}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{24}, + result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := ParseUntypedBinaryRange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if bytes.Compare(r.Lower, tt.result.Lower) != 0 { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if bytes.Compare(r.Upper, tt.result.Upper) != 0 { + t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} diff --git a/pgtype/record.go b/pgtype/record.go new file mode 100644 index 00000000..20b119c6 --- /dev/null +++ b/pgtype/record.go @@ -0,0 +1,119 @@ +package pgtype + +import ( + "fmt" + "reflect" +) + +// Record is the generic PostgreSQL record type such as is created with the +// "row" function. Record only implements BinaryEncoder and Value. The text +// format output format from PostgreSQL does not include type information and is +// therefore impossible to decode. No encoders are implemented because +// PostgreSQL does not support input of generic records. +type Record struct { + Fields []Value + Valid bool +} + +func (dst *Record) Set(src interface{}) error { + if src == nil { + *dst = Record{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case []Value: + *dst = Record{Fields: value, Valid: true} + default: + return fmt.Errorf("cannot convert %v to Record", src) + } + + return nil +} + +func (dst Record) Get() interface{} { + if !dst.Valid { + return nil + } + return dst.Fields +} + +func (src *Record) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *[]Value: + *v = make([]Value, len(src.Fields)) + copy(*v, src.Fields) + return nil + case *[]interface{}: + *v = make([]interface{}, len(src.Fields)) + for i := range *v { + (*v)[i] = src.Fields[i].Get() + } + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } +} + +func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDecoder, error) { + var binaryDecoder BinaryDecoder + + if dt, ok := ci.DataTypeForOID(fieldOID); ok { + binaryDecoder, _ = dt.Value.(BinaryDecoder) + } else { + return nil, fmt.Errorf("unknown oid while decoding record: %v", fieldOID) + } + + if binaryDecoder == nil { + return nil, fmt.Errorf("no binary decoder registered for: %v", fieldOID) + } + + // Duplicate struct to scan into + binaryDecoder = reflect.New(reflect.ValueOf(binaryDecoder).Elem().Type()).Interface().(BinaryDecoder) + *v = binaryDecoder.(Value) + return binaryDecoder, nil +} + +func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Record{} + return nil + } + + scanner := NewCompositeBinaryScanner(ci, src) + + fields := make([]Value, scanner.FieldCount()) + + for i := 0; scanner.Next(); i++ { + binaryDecoder, err := prepareNewBinaryDecoder(ci, scanner.OID(), &fields[i]) + if err != nil { + return err + } + + if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil { + return err + } + } + + if scanner.Err() != nil { + return scanner.Err() + } + + *dst = Record{Fields: fields, Valid: true} + + return nil +} diff --git a/pgtype/record_test.go b/pgtype/record_test.go new file mode 100644 index 00000000..c8e7d4b7 --- /dev/null +++ b/pgtype/record_test.go @@ -0,0 +1,184 @@ +package pgtype_test + +import ( + "context" + "fmt" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" +) + +var recordTests = []struct { + sql string + expected pgtype.Record +}{ + { + sql: `select row()`, + expected: pgtype.Record{ + Fields: []pgtype.Value{}, + Valid: true, + }, + }, + { + sql: `select row('foo'::text, 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Valid: true}, + &pgtype.Int4{Int: 42, Valid: true}, + }, + Valid: true, + }, + }, + { + sql: `select row(100.0::float4, 1.09::float4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Float4{Float: 100, Valid: true}, + &pgtype.Float4{Float: 1.09, Valid: true}, + }, + Valid: true, + }, + }, + { + sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Valid: true}, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + {Int: 1, Valid: true}, + {Int: 2, Valid: true}, + {}, + {Int: 4, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, + Valid: true, + }, + &pgtype.Int4{Int: 42, Valid: true}, + }, + Valid: true, + }, + }, + { + sql: `select row(null)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Unknown{}, + }, + Valid: true, + }, + }, + { + sql: `select null::record`, + expected: pgtype.Record{}, + }, +} + +func TestRecordTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + for i, tt := range recordTests { + psName := fmt.Sprintf("test%d", i) + _, err := conn.Prepare(context.Background(), psName, tt.sql) + if err != nil { + t.Fatal(err) + } + + t.Run(tt.sql, func(t *testing.T) { + var result pgtype.Record + if err := conn.QueryRow(context.Background(), psName, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result); err != nil { + t.Errorf("%v", err) + return + } + + if !reflect.DeepEqual(tt.expected, result) { + t.Errorf("expected %#v, got %#v", tt.expected, result) + } + }) + + } +} + +func TestRecordWithUnknownOID(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists floatrange; + +create type floatrange as range ( + subtype = float8, + subtype_diff = float8mi +);`) + if err != nil { + t.Fatal(err) + } + defer conn.Exec(context.Background(), "drop type floatrange") + + var result pgtype.Record + err = conn.QueryRow(context.Background(), "select row('foo'::text, floatrange(1, 10), 'bar'::text)").Scan(&result) + if err == nil { + t.Errorf("expected error but none") + } +} + +func TestRecordAssignTo(t *testing.T) { + var valueSlice []pgtype.Value + var interfaceSlice []interface{} + + simpleTests := []struct { + src pgtype.Record + dst interface{} + expected interface{} + }{ + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Valid: true}, + &pgtype.Int4{Int: 42, Valid: true}, + }, + Valid: true, + }, + dst: &valueSlice, + expected: []pgtype.Value{ + &pgtype.Text{String: "foo", Valid: true}, + &pgtype.Int4{Int: 42, Valid: true}, + }, + }, + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Valid: true}, + &pgtype.Int4{Int: 42, Valid: true}, + }, + Valid: true, + }, + dst: &interfaceSlice, + expected: []interface{}{"foo", int32(42)}, + }, + { + src: pgtype.Record{}, + dst: &valueSlice, + expected: (([]pgtype.Value)(nil)), + }, + { + src: pgtype.Record{}, + dst: &interfaceSlice, + expected: (([]interface{})(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/testutil/testutil.go b/pgtype/testutil/testutil.go new file mode 100644 index 00000000..5dded2b9 --- /dev/null +++ b/pgtype/testutil/testutil.go @@ -0,0 +1,425 @@ +package testutil + +import ( + "context" + "database/sql" + "fmt" + "os" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" + _ "github.com/jackc/pgx/v4/stdlib" +) + +func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { + var sqlDriverName string + switch driverName { + case "github.com/jackc/pgx/stdlib": + sqlDriverName = "pgx" + default: + t.Fatalf("Unknown driver %v", driverName) + } + + db, err := sql.Open(sqlDriverName, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + + return db +} + +func MustConnectPgx(t testing.TB) *pgx.Conn { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + + return conn +} + +func MustClose(t testing.TB, conn interface { + Close() error +}) { + err := conn.Close() + if err != nil { + t.Fatal(err) + } +} + +func MustCloseContext(t testing.TB, conn interface { + Close(context.Context) error +}) { + err := conn.Close(context.Background()) + if err != nil { + t.Fatal(err) + } +} + +type forceTextEncoder struct { + e pgtype.TextEncoder +} + +func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + return f.e.EncodeText(ci, buf) +} + +type forceBinaryEncoder struct { + e pgtype.BinaryEncoder +} + +func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + return f.e.EncodeBinary(ci, buf) +} + +func ForceEncoder(e interface{}, formatCode int16) interface{} { + switch formatCode { + case pgx.TextFormatCode: + if e, ok := e.(pgtype.TextEncoder); ok { + return forceTextEncoder{e: e} + } + case pgx.BinaryFormatCode: + if e, ok := e.(pgtype.BinaryEncoder); ok { + return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} + } + } + return nil +} + +func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { + TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + TestDatabaseSQLSuccessfulTranscodeEqFunc(t, "github.com/jackc/pgx/stdlib", pgTypeName, values, eqFunc) +} + +func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := MustConnectPgx(t) + defer MustCloseContext(t, conn) + + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for i, v := range values { + for _, paramFormat := range formats { + for _, resultFormat := range formats { + vEncoder := ForceEncoder(v, paramFormat.formatCode) + if vEncoder == nil { + t.Logf("Skipping Param %s Result %s: %#v does not implement %v for encoding", paramFormat.name, resultFormat.name, v, paramFormat.name) + continue + } + switch resultFormat.formatCode { + case pgx.TextFormatCode: + if _, ok := v.(pgtype.TextEncoder); !ok { + t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name) + continue + } + case pgx.BinaryFormatCode: + if _, ok := v.(pgtype.BinaryEncoder); !ok { + t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name) + continue + } + } + + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + + err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{resultFormat.formatCode}, vEncoder).Scan(result.Interface()) + if err != nil { + t.Errorf("Param %s Result %s %d: %v", paramFormat.name, resultFormat.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("Param %s Result %s %d: expected %v, got %v", paramFormat.name, resultFormat.name, i, derefV, result.Elem().Interface()) + } + } + } + } +} + +func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := ps.QueryRow(v).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } +} + +type NormalizeTest struct { + SQL string + Value interface{} +} + +func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) { + TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc) + TestDatabaseSQLSuccessfulNormalizeEqFunc(t, "github.com/jackc/pgx/stdlib", tests, eqFunc) +} + +func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + conn := MustConnectPgx(t) + defer MustCloseContext(t, conn) + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for i, tt := range tests { + for _, fc := range formats { + psName := fmt.Sprintf("test%d", i) + _, err := conn.Prepare(context.Background(), psName, tt.SQL) + if err != nil { + t.Fatal(err) + } + + queryResultFormats := pgx.QueryResultFormats{fc.formatCode} + if ForceEncoder(tt.Value, fc.formatCode) == nil { + t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name) + continue + } + // Derefence value if it is a pointer + derefV := tt.Value + refVal := reflect.ValueOf(tt.Value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = conn.QueryRow(context.Background(), psName, queryResultFormats).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} + +func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) + + for i, tt := range tests { + ps, err := conn.Prepare(tt.SQL) + if err != nil { + t.Errorf("%d. %v", i, err) + continue + } + + // Derefence value if it is a pointer + derefV := tt.Value + refVal := reflect.ValueOf(tt.Value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = ps.QueryRow().Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } +} + +func TestGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) { + TestPgxGoZeroToNullConversion(t, pgTypeName, zero) + TestDatabaseSQLGoZeroToNullConversion(t, "github.com/jackc/pgx/stdlib", pgTypeName, zero) +} + +func TestNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) { + TestPgxNullToGoZeroConversion(t, pgTypeName, zero) + TestDatabaseSQLNullToGoZeroConversion(t, "github.com/jackc/pgx/stdlib", pgTypeName, zero) +} + +func TestPgxGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) { + conn := MustConnectPgx(t) + defer MustCloseContext(t, conn) + + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s is null", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, paramFormat := range formats { + vEncoder := ForceEncoder(zero, paramFormat.formatCode) + if vEncoder == nil { + t.Logf("Skipping Param %s: %#v does not implement %v for encoding", paramFormat.name, zero, paramFormat.name) + continue + } + + var result bool + err := conn.QueryRow(context.Background(), "test", vEncoder).Scan(&result) + if err != nil { + t.Errorf("Param %s: %v", paramFormat.name, err) + } + + if !result { + t.Errorf("Param %s: did not convert zero to null", paramFormat.name) + } + } +} + +func TestPgxNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) { + conn := MustConnectPgx(t) + defer MustCloseContext(t, conn) + + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select null::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, resultFormat := range formats { + + switch resultFormat.formatCode { + case pgx.TextFormatCode: + if _, ok := zero.(pgtype.TextEncoder); !ok { + t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name) + continue + } + case pgx.BinaryFormatCode: + if _, ok := zero.(pgtype.BinaryEncoder); !ok { + t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name) + continue + } + } + + // Derefence value if it is a pointer + derefZero := zero + refVal := reflect.ValueOf(zero) + if refVal.Kind() == reflect.Ptr { + derefZero = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefZero)) + + err := conn.QueryRow(context.Background(), "test").Scan(result.Interface()) + if err != nil { + t.Errorf("Result %s: %v", resultFormat.name, err) + } + + if !reflect.DeepEqual(result.Elem().Interface(), derefZero) { + t.Errorf("Result %s: did not convert null to zero", resultFormat.name) + } + } +} + +func TestDatabaseSQLGoZeroToNullConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) { + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select $1::%s is null", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + var result bool + err = ps.QueryRow(zero).Scan(&result) + if err != nil { + t.Errorf("%v %v", driverName, err) + } + + if !result { + t.Errorf("%v: did not convert zero to null", driverName) + } +} + +func TestDatabaseSQLNullToGoZeroConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) { + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select null::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + // Derefence value if it is a pointer + derefZero := zero + refVal := reflect.ValueOf(zero) + if refVal.Kind() == reflect.Ptr { + derefZero = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefZero)) + + err = ps.QueryRow().Scan(result.Interface()) + if err != nil { + t.Errorf("%v %v", driverName, err) + } + + if !reflect.DeepEqual(result.Elem().Interface(), derefZero) { + t.Errorf("%s: did not convert null to zero", driverName) + } +} diff --git a/pgtype/text.go b/pgtype/text.go new file mode 100644 index 00000000..5d27c44f --- /dev/null +++ b/pgtype/text.go @@ -0,0 +1,193 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/json" + "fmt" +) + +type Text struct { + String string + Valid bool +} + +func (dst *Text) Set(src interface{}) error { + if src == nil { + *dst = Text{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case string: + *dst = Text{String: value, Valid: true} + case *string: + if value == nil { + *dst = Text{} + } else { + *dst = Text{String: *value, Valid: true} + } + case []byte: + if value == nil { + *dst = Text{} + } else { + *dst = Text{String: string(value), Valid: true} + } + case fmt.Stringer: + if value == fmt.Stringer(nil) { + *dst = Text{} + } else { + *dst = Text{String: value.String(), Valid: true} + } + default: + // Cannot be part of the switch: If Value() returns nil on + // non-string, we should still try to checks the underlying type + // using reflection. + // + // For example the struct might implement driver.Valuer with + // pointer receiver and fmt.Stringer with value receiver. + if value, ok := src.(driver.Valuer); ok { + if value == driver.Valuer(nil) { + *dst = Text{} + return nil + } else { + v, err := value.Value() + if err != nil { + return fmt.Errorf("driver.Valuer Value() method failed: %w", err) + } + + // Handles also v == nil case. + if s, ok := v.(string); ok { + *dst = Text{String: s, Valid: true} + return nil + } + } + } + + if originalSrc, ok := underlyingStringType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Text", value) + } + + return nil +} + +func (dst Text) Get() interface{} { + if !dst.Valid { + return nil + } + return dst.String +} + +func (src *Text) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *string: + *v = src.String + return nil + case *[]byte: + *v = make([]byte, len(src.String)) + copy(*v, src.String) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } +} + +func (Text) PreferredResultFormat() int16 { + return TextFormatCode +} + +func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Text{} + return nil + } + + *dst = Text{String: string(src), Valid: true} + return nil +} + +func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) +} + +func (Text) PreferredParamFormat() int16 { + return TextFormatCode +} + +func (src Text) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return append(buf, src.String...), nil +} + +func (src Text) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return src.EncodeText(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Text) Scan(src interface{}) error { + if src == nil { + *dst = Text{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Text) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + return src.String, nil +} + +func (src Text) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + return json.Marshal(src.String) +} + +func (dst *Text) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *dst = Text{} + } else { + *dst = Text{String: *s, Valid: true} + } + + return nil +} diff --git a/pgtype/text_array.go b/pgtype/text_array.go new file mode 100644 index 00000000..7fcc1c4d --- /dev/null +++ b/pgtype/text_array.go @@ -0,0 +1,504 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type TextArray struct { + Elements []Text + Dimensions []ArrayDimension + Valid bool +} + +func (dst *TextArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = TextArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []string: + if value == nil { + *dst = TextArray{} + } else if len(value) == 0 { + *dst = TextArray{Valid: true} + } else { + elements := make([]Text, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TextArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*string: + if value == nil { + *dst = TextArray{} + } else if len(value) == 0 { + *dst = TextArray{Valid: true} + } else { + elements := make([]Text, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TextArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Text: + if value == nil { + *dst = TextArray{} + } else if len(value) == 0 { + *dst = TextArray{Valid: true} + } else { + *dst = TextArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = TextArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for TextArray", src) + } + if elementsLength == 0 { + *dst = TextArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to TextArray", src) + } + + *dst = TextArray{ + Elements: make([]Text, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Text, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to TextArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *TextArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to TextArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in TextArray", err) + } + index++ + + return index, nil +} + +func (dst TextArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *TextArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *TextArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from TextArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from TextArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TextArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Text + + if len(uta.Elements) > 0 { + elements = make([]Text, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Text + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TextArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TextArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TextArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Text, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = TextArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("text"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "text") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TextArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src TextArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/text_array_test.go b/pgtype/text_array_test.go new file mode 100644 index 00000000..4caeb692 --- /dev/null +++ b/pgtype/text_array_test.go @@ -0,0 +1,294 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// https://github.com/jackc/pgtype/issues/78 +func TestTextArrayDecodeTextNull(t *testing.T) { + textArray := &pgtype.TextArray{} + err := textArray.DecodeText(nil, []byte(`{abc,"NULL",NULL,def}`)) + require.NoError(t, err) + require.Len(t, textArray.Elements, 4) + assert.Equal(t, true, textArray.Elements[1].Valid) + assert.Equal(t, false, textArray.Elements[2].Valid) +} + +func TestTextArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "text[]", []interface{}{ + &pgtype.TextArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.TextArray{}, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "bar ", Valid: true}, + {String: "NuLL", Valid: true}, + {String: `wow"quz\`, Valid: true}, + {String: "", Valid: true}, + {}, + {String: "null", Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "quz", Valid: true}, + {String: "foo", Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestTextArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.TextArray + }{ + { + source: []string{"foo"}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]string)(nil)), + result: pgtype.TextArray{}, + }, + { + source: [][]string{{"foo"}, {"bar"}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1]string{{"foo"}, {"bar"}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TextArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTextArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string + + simpleTests := []struct { + src pgtype.TextArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.TextArray{}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + { + src: pgtype.TextArray{Valid: true}, + dst: &stringSlice, + expected: []string{}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &stringSliceDim2, + expected: [][]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &stringSliceDim4, + expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &stringArrayDim2, + expected: [2][1]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{ + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TextArray + dst interface{} + }{ + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &stringSlice, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &stringArrayDim2, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &stringSlice, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &stringArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/text_test.go b/pgtype/text_test.go new file mode 100644 index 00000000..5f34f8c0 --- /dev/null +++ b/pgtype/text_test.go @@ -0,0 +1,164 @@ +package pgtype_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestTextTranscode(t *testing.T) { + for _, pgTypeName := range []string{"text", "varchar"} { + testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ + &pgtype.Text{String: "", Valid: true}, + &pgtype.Text{String: "foo", Valid: true}, + &pgtype.Text{}, + }) + } +} + +func TestTextSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Text + }{ + {source: "foo", result: pgtype.Text{String: "foo", Valid: true}}, + {source: _string("bar"), result: pgtype.Text{String: "bar", Valid: true}}, + {source: (*string)(nil), result: pgtype.Text{}}, + } + + for i, tt := range successfulTests { + var d pgtype.Text + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestTextAssignTo(t *testing.T) { + var s string + var ps *string + + stringTests := []struct { + src pgtype.Text + dst interface{} + expected interface{} + }{ + {src: pgtype.Text{String: "foo", Valid: true}, dst: &s, expected: "foo"}, + {src: pgtype.Text{}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range stringTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + var buf []byte + + bytesTests := []struct { + src pgtype.Text + dst *[]byte + expected []byte + }{ + {src: pgtype.Text{String: "foo", Valid: true}, dst: &buf, expected: []byte("foo")}, + {src: pgtype.Text{}, dst: &buf, expected: nil}, + } + + for i, tt := range bytesTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if bytes.Compare(*tt.dst, tt.expected) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Text + dst interface{} + expected interface{} + }{ + {src: pgtype.Text{String: "foo", Valid: true}, dst: &ps, expected: "foo"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Text + dst interface{} + }{ + {src: pgtype.Text{}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} + +func TestTextMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Text + result string + }{ + {source: pgtype.Text{String: ""}, result: "null"}, + {source: pgtype.Text{String: "a", Valid: true}, result: "\"a\""}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestTextUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Text + }{ + {source: "null", result: pgtype.Text{String: ""}}, + {source: "\"a\"", result: pgtype.Text{String: "a", Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Text + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/tid.go b/pgtype/tid.go new file mode 100644 index 00000000..0108d219 --- /dev/null +++ b/pgtype/tid.go @@ -0,0 +1,146 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +// TID is PostgreSQL's Tuple Identifier type. +// +// When one does +// +// select ctid, * from some_table; +// +// it is the data type of the ctid hidden system column. +// +// It is currently implemented as a pair unsigned two byte integers. +// Its conversion functions can be found in src/backend/utils/adt/tid.c +// in the PostgreSQL sources. +type TID struct { + BlockNumber uint32 + OffsetNumber uint16 + Valid bool +} + +func (dst *TID) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to TID", src) +} + +func (dst TID) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *TID) AssignTo(dst interface{}) error { + if !src.Valid { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + + switch v := dst.(type) { + case *string: + *v = fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } +} + +func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TID{} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return fmt.Errorf("invalid format for tid") + } + + blockNumber, err := strconv.ParseUint(parts[0], 10, 32) + if err != nil { + return err + } + + offsetNumber, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return err + } + + *dst = TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Valid: true} + return nil +} + +func (dst *TID) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TID{} + return nil + } + + if len(src) != 6 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + *dst = TID{ + BlockNumber: binary.BigEndian.Uint32(src), + OffsetNumber: binary.BigEndian.Uint16(src[4:]), + Valid: true, + } + return nil +} + +func (src TID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber)...) + return buf, nil +} + +func (src TID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = pgio.AppendUint32(buf, src.BlockNumber) + buf = pgio.AppendUint16(buf, src.OffsetNumber) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TID) Scan(src interface{}) error { + if src == nil { + *dst = TID{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src TID) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go new file mode 100644 index 00000000..fcf93259 --- /dev/null +++ b/pgtype/tid_test.go @@ -0,0 +1,62 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestTIDTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "tid", []interface{}{ + &pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, + &pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, + &pgtype.TID{}, + }) +} + +func TestTIDAssignTo(t *testing.T) { + var s string + var sp *string + + simpleTests := []struct { + src pgtype.TID + dst interface{} + expected interface{} + }{ + {src: pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, dst: &s, expected: "(42,43)"}, + {src: pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, dst: &s, expected: "(4294967295,65535)"}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.TID + dst interface{} + expected interface{} + }{ + {src: pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, dst: &sp, expected: "(42,43)"}, + {src: pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, dst: &sp, expected: "(4294967295,65535)"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/time.go b/pgtype/time.go new file mode 100644 index 00000000..3252a633 --- /dev/null +++ b/pgtype/time.go @@ -0,0 +1,218 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "time" + + "github.com/jackc/pgio" +) + +// Time represents the PostgreSQL time type. The PostgreSQL time is a time of day without time zone. +// +// Time is represented as the number of microseconds since midnight in the same way that PostgreSQL does. Other time +// and date types in pgtype can use time.Time as the underlying representation. However, pgtype.Time type cannot due +// to needing to handle 24:00:00. time.Time converts that to 00:00:00 on the following day. +type Time struct { + Microseconds int64 // Number of microseconds since midnight + Valid bool +} + +// Set converts src into a Time and stores in dst. +func (dst *Time) Set(src interface{}) error { + if src == nil { + *dst = Time{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case time.Time: + usec := int64(value.Hour())*microsecondsPerHour + + int64(value.Minute())*microsecondsPerMinute + + int64(value.Second())*microsecondsPerSecond + + int64(value.Nanosecond())/1000 + *dst = Time{Microseconds: usec, Valid: true} + case *time.Time: + if value == nil { + *dst = Time{} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Time", value) + } + + return nil +} + +func (dst Time) Get() interface{} { + if !dst.Valid { + return nil + } + return dst.Microseconds +} + +func (src *Time) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *time.Time: + // 24:00:00 is max allowed time in PostgreSQL, but time.Time will normalize that to 00:00:00 the next day. + var maxRepresentableByTime int64 = 24*60*60*1000000 - 1 + if src.Microseconds > maxRepresentableByTime { + return fmt.Errorf("%d microseconds cannot be represented as time.Time", src.Microseconds) + } + + usec := src.Microseconds + hours := usec / microsecondsPerHour + usec -= hours * microsecondsPerHour + minutes := usec / microsecondsPerMinute + usec -= minutes * microsecondsPerMinute + seconds := usec / microsecondsPerSecond + usec -= seconds * microsecondsPerSecond + ns := usec * 1000 + *v = time.Date(2000, 1, 1, int(hours), int(minutes), int(seconds), int(ns), time.UTC) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } +} + +// DecodeText decodes from src into dst. +func (dst *Time) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Time{} + return nil + } + + s := string(src) + + if len(s) < 8 { + return fmt.Errorf("cannot decode %v into Time", s) + } + + hours, err := strconv.ParseInt(s[0:2], 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + usec := hours * microsecondsPerHour + + minutes, err := strconv.ParseInt(s[3:5], 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + usec += minutes * microsecondsPerMinute + + seconds, err := strconv.ParseInt(s[6:8], 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + usec += seconds * microsecondsPerSecond + + if len(s) > 9 { + fraction := s[9:] + n, err := strconv.ParseInt(fraction, 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + + for i := len(fraction); i < 6; i++ { + n *= 10 + } + + usec += n + } + + *dst = Time{Microseconds: usec, Valid: true} + + return nil +} + +// DecodeBinary decodes from src into dst. +func (dst *Time) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Time{} + return nil + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for time: %v", len(src)) + } + + usec := int64(binary.BigEndian.Uint64(src)) + *dst = Time{Microseconds: usec, Valid: true} + + return nil +} + +// EncodeText writes the text encoding of src into w. +func (src Time) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + usec := src.Microseconds + hours := usec / microsecondsPerHour + usec -= hours * microsecondsPerHour + minutes := usec / microsecondsPerMinute + usec -= minutes * microsecondsPerMinute + seconds := usec / microsecondsPerSecond + usec -= seconds * microsecondsPerSecond + + s := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, usec) + + return append(buf, s...), nil +} + +// EncodeBinary writes the binary encoding of src into w. If src.Time is not in +// the UTC time zone it returns an error. +func (src Time) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return pgio.AppendInt64(buf, src.Microseconds), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Time) Scan(src interface{}) error { + if src == nil { + *dst = Time{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + case time.Time: + return dst.Set(src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Time) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/time_test.go b/pgtype/time_test.go new file mode 100644 index 00000000..4a989375 --- /dev/null +++ b/pgtype/time_test.go @@ -0,0 +1,117 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestTimeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "time", []interface{}{ + &pgtype.Time{Microseconds: 0, Valid: true}, + &pgtype.Time{Microseconds: 1, Valid: true}, + &pgtype.Time{Microseconds: 86399999999, Valid: true}, + &pgtype.Time{Microseconds: 86400000000, Valid: true}, + &pgtype.Time{}, + }) +} + +func TestTimeSet(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Time + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 0, Valid: true}}, + {source: time.Date(1900, 1, 1, 1, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 3600000000, Valid: true}}, + {source: time.Date(1900, 1, 1, 0, 1, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 60000000, Valid: true}}, + {source: time.Date(1900, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Time{Microseconds: 1000000, Valid: true}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 1, time.UTC), result: pgtype.Time{Microseconds: 0, Valid: true}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 1000, time.UTC), result: pgtype.Time{Microseconds: 1, Valid: true}}, + {source: time.Date(1999, 12, 31, 23, 59, 59, 999999999, time.UTC), result: pgtype.Time{Microseconds: 86399999999, Valid: true}}, + {source: time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local), result: pgtype.Time{Microseconds: 2, Valid: true}}, + {source: func(t time.Time) *time.Time { return &t }(time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local)), result: pgtype.Time{Microseconds: 2, Valid: true}}, + {source: nil, result: pgtype.Time{}}, + {source: (*time.Time)(nil), result: pgtype.Time{}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 3000, time.UTC)), result: pgtype.Time{Microseconds: 3, Valid: true}}, + } + + for i, tt := range successfulTests { + var r pgtype.Time + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimeAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Time + dst interface{} + expected interface{} + }{ + {src: pgtype.Time{Microseconds: 0, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 3600000000, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 1, 0, 0, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 60000000, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 1, 0, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 1000000, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC)}, + {src: pgtype.Time{Microseconds: 1, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 1000, time.UTC)}, + {src: pgtype.Time{Microseconds: 86399999999, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 23, 59, 59, 999999000, time.UTC)}, + {src: pgtype.Time{Microseconds: 0}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Time + dst interface{} + expected interface{} + }{ + {src: pgtype.Time{Microseconds: 0, Valid: true}, dst: &ptim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Time + dst interface{} + }{ + {src: pgtype.Time{Microseconds: 86400000000, Valid: true}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go new file mode 100644 index 00000000..882cd41a --- /dev/null +++ b/pgtype/timestamp.go @@ -0,0 +1,227 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "time" + + "github.com/jackc/pgio" +) + +const pgTimestampFormat = "2006-01-02 15:04:05.999999999" + +// Timestamp represents the PostgreSQL timestamp type. The PostgreSQL +// timestamp does not have a time zone. This presents a problem when +// translating to and from time.Time which requires a time zone. It is highly +// recommended to use timestamptz whenever possible. Timestamp methods either +// convert to UTC or return an error on non-UTC times. +type Timestamp struct { + Time time.Time // Time must always be in UTC. + Valid bool + InfinityModifier InfinityModifier +} + +// Set converts src into a Timestamp and stores in dst. If src is a +// time.Time in a non-UTC time zone, the time zone is discarded. +func (dst *Timestamp) Set(src interface{}) error { + if src == nil { + *dst = Timestamp{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case time.Time: + *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Valid: true} + case *time.Time: + if value == nil { + *dst = Timestamp{} + } else { + return dst.Set(*value) + } + case InfinityModifier: + *dst = Timestamp{InfinityModifier: value, Valid: true} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Timestamp", value) + } + + return nil +} + +func (dst Timestamp) Get() interface{} { + if !dst.Valid { + return nil + } + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst.Time +} + +func (src *Timestamp) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } +} + +// DecodeText decodes from src into dst. The decoded time is considered to +// be in UTC. +func (dst *Timestamp) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Timestamp{} + return nil + } + + sbuf := string(src) + switch sbuf { + case "infinity": + *dst = Timestamp{Valid: true, InfinityModifier: Infinity} + case "-infinity": + *dst = Timestamp{Valid: true, InfinityModifier: -Infinity} + default: + tim, err := time.Parse(pgTimestampFormat, sbuf) + if err != nil { + return err + } + + *dst = Timestamp{Time: tim, Valid: true} + } + + return nil +} + +// DecodeBinary decodes from src into dst. The decoded time is considered to +// be in UTC. +func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Timestamp{} + return nil + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamp: %v", len(src)) + } + + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + *dst = Timestamp{Valid: true, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + *dst = Timestamp{Valid: true, InfinityModifier: -Infinity} + default: + tim := time.Unix( + microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, + (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + ).UTC() + *dst = Timestamp{Time: tim, Valid: true} + } + + return nil +} + +// EncodeText writes the text encoding of src into w. If src.Time is not in +// the UTC time zone it returns an error. +func (src Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + if src.Time.Location() != time.UTC { + return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Truncate(time.Microsecond).Format(pgTimestampFormat) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return append(buf, s...), nil +} + +// EncodeBinary writes the binary encoding of src into w. If src.Time is not in +// the UTC time zone it returns an error. +func (src Timestamp) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + if src.Time.Location() != time.UTC { + return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") + } + + var microsecSinceY2K int64 + switch src.InfinityModifier { + case None: + microsecSinceUnixEpoch := src.Time.Unix()*1000000 + int64(src.Time.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + return pgio.AppendInt64(buf, microsecSinceY2K), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamp) Scan(src interface{}) error { + if src == nil { + *dst = Timestamp{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + case time.Time: + *dst = Timestamp{Time: src, Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamp) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil +} diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go new file mode 100644 index 00000000..fbf7c48a --- /dev/null +++ b/pgtype/timestamp_array.go @@ -0,0 +1,505 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + "time" + + "github.com/jackc/pgio" +) + +type TimestampArray struct { + Elements []Timestamp + Dimensions []ArrayDimension + Valid bool +} + +func (dst *TimestampArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = TimestampArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []time.Time: + if value == nil { + *dst = TimestampArray{} + } else if len(value) == 0 { + *dst = TimestampArray{Valid: true} + } else { + elements := make([]Timestamp, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TimestampArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*time.Time: + if value == nil { + *dst = TimestampArray{} + } else if len(value) == 0 { + *dst = TimestampArray{Valid: true} + } else { + elements := make([]Timestamp, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TimestampArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Timestamp: + if value == nil { + *dst = TimestampArray{} + } else if len(value) == 0 { + *dst = TimestampArray{Valid: true} + } else { + *dst = TimestampArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = TimestampArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for TimestampArray", src) + } + if elementsLength == 0 { + *dst = TimestampArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to TimestampArray", src) + } + + *dst = TimestampArray{ + Elements: make([]Timestamp, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Timestamp, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to TimestampArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *TimestampArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to TimestampArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in TimestampArray", err) + } + index++ + + return index, nil +} + +func (dst TimestampArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *TimestampArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]time.Time: + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*time.Time: + *v = make([]*time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *TimestampArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from TimestampArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from TimestampArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TimestampArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Timestamp + + if len(uta.Elements) > 0 { + elements = make([]Timestamp, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Timestamp + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TimestampArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TimestampArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TimestampArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Timestamp, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = TimestampArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("timestamp"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "timestamp") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TimestampArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src TimestampArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/timestamp_array_test.go b/pgtype/timestamp_array_test.go new file mode 100644 index 00000000..214c8a71 --- /dev/null +++ b/pgtype/timestamp_array_test.go @@ -0,0 +1,307 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestTimestampArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp[]", []interface{}{ + &pgtype.TimestampArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.TimestampArray{}, + &pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }, func(a, b interface{}) bool { + ata := a.(pgtype.TimestampArray) + bta := b.(pgtype.TimestampArray) + + if len(ata.Elements) != len(bta.Elements) || ata.Valid != bta.Valid { + return false + } + + for i := range ata.Elements { + ae, be := ata.Elements[i], bta.Elements[i] + if !(ae.Time.Equal(be.Time) && ae.Valid == be.Valid && ae.InfinityModifier == be.InfinityModifier) { + return false + } + } + + return true + }) +} + +func TestTimestampArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.TimestampArray + }{ + { + source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + result: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]time.Time)(nil)), + result: pgtype.TimestampArray{}, + }, + { + source: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TimestampArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestampArrayAssignTo(t *testing.T) { + var timeSlice []time.Time + var timeSliceDim2 [][]time.Time + var timeSliceDim4 [][][][]time.Time + var timeArrayDim2 [2][1]time.Time + var timeArrayDim4 [2][1][1][3]time.Time + + simpleTests := []struct { + src pgtype.TimestampArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &timeSlice, + expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + }, + { + src: pgtype.TimestampArray{}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, + { + src: pgtype.TimestampArray{Valid: true}, + dst: &timeSlice, + expected: []time.Time{}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &timeSliceDim2, + expected: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &timeSliceDim4, + expected: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &timeArrayDim2, + expected: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &timeArrayDim4, + expected: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TimestampArray + dst interface{} + }{ + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &timeSlice, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &timeArrayDim2, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &timeSlice, + }, + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &timeArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go new file mode 100644 index 00000000..88e2bca8 --- /dev/null +++ b/pgtype/timestamp_test.go @@ -0,0 +1,199 @@ +package pgtype_test + +import ( + "context" + "reflect" + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestTimestampTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ + &pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + &pgtype.Timestamp{}, + &pgtype.Timestamp{Valid: true, InfinityModifier: pgtype.Infinity}, + &pgtype.Timestamp{Valid: true, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Timestamp) + bt := b.(pgtype.Timestamp) + + return at.Time.Equal(bt.Time) && at.Valid == bt.Valid && at.InfinityModifier == bt.InfinityModifier + }) +} + +// https://github.com/jackc/pgtype/pull/128 +func TestTimestampTranscodeBigTimeBinary(t *testing.T) { + conn := testutil.MustConnectPgx(t) + if _, ok := conn.ConnInfo().DataTypeForName("line"); !ok { + t.Skip("Skipping due to no line type") + } + defer testutil.MustCloseContext(t, conn) + + in := &pgtype.Timestamp{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Valid: true} + var out pgtype.Timestamp + + err := conn.QueryRow(context.Background(), "select $1::timestamptz", in).Scan(&out) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, in.Valid, out.Valid) + require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) +} + +func TestTimestampNanosecondsTruncated(t *testing.T) { + tests := []struct { + input time.Time + expected time.Time + }{ + {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.UTC), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC)}, + {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.UTC), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.UTC)}, + } + for i, tt := range tests { + { + ts := pgtype.Timestamp{Time: tt.input, Valid: true} + buf, err := ts.EncodeText(nil, nil) + if err != nil { + t.Errorf("%d. EncodeText failed - %v", i, err) + } + + ts.DecodeText(nil, buf) + if err != nil { + t.Errorf("%d. DecodeText failed - %v", i, err) + } + + if !(ts.Valid && ts.Time.Equal(tt.expected)) { + t.Errorf("%d. EncodeText did not truncate nanoseconds", i) + } + } + + { + ts := pgtype.Timestamp{Time: tt.input, Valid: true} + buf, err := ts.EncodeBinary(nil, nil) + if err != nil { + t.Errorf("%d. EncodeBinary failed - %v", i, err) + } + + ts.DecodeBinary(nil, buf) + if err != nil { + t.Errorf("%d. DecodeBinary failed - %v", i, err) + } + + if !(ts.Valid && ts.Time.Equal(tt.expected)) { + t.Errorf("%d. EncodeBinary did not truncate nanoseconds", i) + } + } + } +} + +// https://github.com/jackc/pgtype/issues/74 +func TestTimestampDecodeTextInvalid(t *testing.T) { + tstz := &pgtype.Timestamp{} + err := tstz.DecodeText(nil, []byte(`eeeee`)) + require.Error(t, err) +} + +func TestTimestampSet(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Timestamp + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), Valid: true}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), Valid: true}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + {source: pgtype.Infinity, result: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true}}, + {source: pgtype.NegativeInfinity, result: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, + } + + for i, tt := range successfulTests { + var r pgtype.Timestamp + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestampAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Timestamp + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)}, + {src: pgtype.Timestamp{Time: time.Time{}}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Timestamp + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Timestamp + dst interface{} + }{ + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Valid: true}, dst: &tim}, + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Valid: true}, dst: &tim}, + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go new file mode 100644 index 00000000..2a711ffa --- /dev/null +++ b/pgtype/timestamptz.go @@ -0,0 +1,273 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "time" + + "github.com/jackc/pgio" +) + +const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" +const pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" +const pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" +const microsecFromUnixEpochToY2K = 946684800 * 1000000 + +const ( + negativeInfinityMicrosecondOffset = -9223372036854775808 + infinityMicrosecondOffset = 9223372036854775807 +) + +type Timestamptz struct { + Time time.Time + Valid bool + InfinityModifier InfinityModifier +} + +func (dst *Timestamptz) Set(src interface{}) error { + if src == nil { + *dst = Timestamptz{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case time.Time: + *dst = Timestamptz{Time: value, Valid: true} + case *time.Time: + if value == nil { + *dst = Timestamptz{} + } else { + return dst.Set(*value) + } + case InfinityModifier: + *dst = Timestamptz{InfinityModifier: value, Valid: true} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Timestamptz", value) + } + + return nil +} + +func (dst Timestamptz) Get() interface{} { + if !dst.Valid { + return nil + } + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst.Time +} + +func (src *Timestamptz) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } +} + +func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Timestamptz{} + return nil + } + + sbuf := string(src) + switch sbuf { + case "infinity": + *dst = Timestamptz{Valid: true, InfinityModifier: Infinity} + case "-infinity": + *dst = Timestamptz{Valid: true, InfinityModifier: -Infinity} + default: + var format string + if len(sbuf) >= 9 && (sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+') { + format = pgTimestamptzSecondFormat + } else if len(sbuf) >= 6 && (sbuf[len(sbuf)-6] == '-' || sbuf[len(sbuf)-6] == '+') { + format = pgTimestamptzMinuteFormat + } else { + format = pgTimestamptzHourFormat + } + + tim, err := time.Parse(format, sbuf) + if err != nil { + return err + } + + *dst = Timestamptz{Time: tim, Valid: true} + } + + return nil +} + +func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Timestamptz{} + return nil + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamptz: %v", len(src)) + } + + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + *dst = Timestamptz{Valid: true, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + *dst = Timestamptz{Valid: true, InfinityModifier: -Infinity} + default: + tim := time.Unix( + microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, + (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + ) + *dst = Timestamptz{Time: tim, Valid: true} + } + + return nil +} + +func (src Timestamptz) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.UTC().Truncate(time.Microsecond).Format(pgTimestamptzSecondFormat) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return append(buf, s...), nil +} + +func (src Timestamptz) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + var microsecSinceY2K int64 + switch src.InfinityModifier { + case None: + microsecSinceUnixEpoch := src.Time.Unix()*1000000 + int64(src.Time.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + return pgio.AppendInt64(buf, microsecSinceY2K), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamptz) Scan(src interface{}) error { + if src == nil { + *dst = Timestamptz{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + case time.Time: + *dst = Timestamptz{Time: src, Valid: true} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamptz) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil +} + +func (src Timestamptz) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Format(time.RFC3339Nano) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +func (dst *Timestamptz) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *dst = Timestamptz{} + return nil + } + + switch *s { + case "infinity": + *dst = Timestamptz{Valid: true, InfinityModifier: Infinity} + case "-infinity": + *dst = Timestamptz{Valid: true, InfinityModifier: -Infinity} + default: + // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz + tim, err := time.Parse(time.RFC3339Nano, *s) + if err != nil { + return err + } + + *dst = Timestamptz{Time: tim, Valid: true} + } + + return nil +} diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go new file mode 100644 index 00000000..4523b251 --- /dev/null +++ b/pgtype/timestamptz_array.go @@ -0,0 +1,505 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + "time" + + "github.com/jackc/pgio" +) + +type TimestamptzArray struct { + Elements []Timestamptz + Dimensions []ArrayDimension + Valid bool +} + +func (dst *TimestamptzArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = TimestamptzArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []time.Time: + if value == nil { + *dst = TimestamptzArray{} + } else if len(value) == 0 { + *dst = TimestamptzArray{Valid: true} + } else { + elements := make([]Timestamptz, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TimestamptzArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*time.Time: + if value == nil { + *dst = TimestamptzArray{} + } else if len(value) == 0 { + *dst = TimestamptzArray{Valid: true} + } else { + elements := make([]Timestamptz, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TimestamptzArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Timestamptz: + if value == nil { + *dst = TimestamptzArray{} + } else if len(value) == 0 { + *dst = TimestamptzArray{Valid: true} + } else { + *dst = TimestamptzArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = TimestamptzArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for TimestamptzArray", src) + } + if elementsLength == 0 { + *dst = TimestamptzArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to TimestamptzArray", src) + } + + *dst = TimestamptzArray{ + Elements: make([]Timestamptz, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Timestamptz, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to TimestamptzArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *TimestamptzArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to TimestamptzArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in TimestamptzArray", err) + } + index++ + + return index, nil +} + +func (dst TimestamptzArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *TimestamptzArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]time.Time: + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*time.Time: + *v = make([]*time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *TimestamptzArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from TimestamptzArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from TimestamptzArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TimestamptzArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Timestamptz + + if len(uta.Elements) > 0 { + elements = make([]Timestamptz, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Timestamptz + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TimestamptzArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TimestamptzArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TimestamptzArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Timestamptz, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = TimestamptzArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("timestamptz"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "timestamptz") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TimestamptzArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src TimestamptzArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/timestamptz_array_test.go b/pgtype/timestamptz_array_test.go new file mode 100644 index 00000000..22e07b59 --- /dev/null +++ b/pgtype/timestamptz_array_test.go @@ -0,0 +1,343 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestTimestamptzArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz[]", []interface{}{ + &pgtype.TimestamptzArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.TimestamptzArray{}, + &pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {}, + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }, func(a, b interface{}) bool { + ata := a.(pgtype.TimestamptzArray) + bta := b.(pgtype.TimestamptzArray) + + if len(ata.Elements) != len(bta.Elements) || ata.Valid != bta.Valid { + return false + } + + for i := range ata.Elements { + ae, be := ata.Elements[i], bta.Elements[i] + if !(ae.Time.Equal(be.Time) && ae.Valid == be.Valid && ae.InfinityModifier == be.InfinityModifier) { + return false + } + } + + return true + }) +} + +func TestTimestamptzArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.TimestamptzArray + }{ + { + source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]time.Time)(nil)), + result: pgtype.TimestamptzArray{}, + }, + { + source: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TimestamptzArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestamptzArrayAssignTo(t *testing.T) { + var timeSlice []time.Time + var timeSliceDim2 [][]time.Time + var timeSliceDim4 [][][][]time.Time + var timeArrayDim2 [2][1]time.Time + var timeArrayDim4 [2][1][1][3]time.Time + + simpleTests := []struct { + src pgtype.TimestamptzArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &timeSlice, + expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + }, + { + src: pgtype.TimestamptzArray{}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, + { + src: pgtype.TimestamptzArray{Valid: true}, + dst: &timeSlice, + expected: []time.Time{}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &timeSliceDim2, + expected: [][]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &timeSliceDim4, + expected: [][][][]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &timeArrayDim2, + expected: [2][1]time.Time{ + {time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + {time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &timeArrayDim4, + expected: [2][1][1][3]time.Time{ + {{{ + time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), + time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), + time.Date(2017, 5, 6, 0, 0, 0, 0, time.UTC)}}}, + {{{ + time.Date(2018, 7, 8, 0, 0, 0, 0, time.UTC), + time.Date(2019, 9, 10, 0, 0, 0, 0, time.UTC), + time.Date(2020, 11, 12, 0, 0, 0, 0, time.UTC)}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TimestamptzArray + dst interface{} + }{ + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &timeSlice, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &timeArrayDim2, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &timeSlice, + }, + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + {Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + {Time: time.Date(2016, 3, 4, 0, 0, 0, 0, time.UTC), Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &timeArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go new file mode 100644 index 00000000..fa2a7e89 --- /dev/null +++ b/pgtype/timestamptz_test.go @@ -0,0 +1,245 @@ +package pgtype_test + +import ( + "context" + "reflect" + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestTimestamptzTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ + &pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, + &pgtype.Timestamptz{}, + &pgtype.Timestamptz{Valid: true, InfinityModifier: pgtype.Infinity}, + &pgtype.Timestamptz{Valid: true, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Timestamptz) + bt := b.(pgtype.Timestamptz) + + return at.Time.Equal(bt.Time) && at.Valid == bt.Valid && at.InfinityModifier == bt.InfinityModifier + }) +} + +// https://github.com/jackc/pgtype/pull/128 +func TestTimestamptzTranscodeBigTimeBinary(t *testing.T) { + conn := testutil.MustConnectPgx(t) + if _, ok := conn.ConnInfo().DataTypeForName("line"); !ok { + t.Skip("Skipping due to no line type") + } + defer testutil.MustCloseContext(t, conn) + + in := &pgtype.Timestamptz{Time: time.Date(294276, 12, 31, 23, 59, 59, 999999000, time.UTC), Valid: true} + var out pgtype.Timestamptz + + err := conn.QueryRow(context.Background(), "select $1::timestamptz", in).Scan(&out) + if err != nil { + t.Fatal(err) + } + + require.Equal(t, in.Valid, out.Valid) + require.Truef(t, in.Time.Equal(out.Time), "expected %v got %v", in.Time, out.Time) +} + +func TestTimestamptzNanosecondsTruncated(t *testing.T) { + tests := []struct { + input time.Time + expected time.Time + }{ + {time.Date(2020, 1, 1, 0, 0, 0, 999999999, time.Local), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local)}, + {time.Date(2020, 1, 1, 0, 0, 0, 999999001, time.Local), time.Date(2020, 1, 1, 0, 0, 0, 999999000, time.Local)}, + } + for i, tt := range tests { + { + tstz := pgtype.Timestamptz{Time: tt.input, Valid: true} + buf, err := tstz.EncodeText(nil, nil) + if err != nil { + t.Errorf("%d. EncodeText failed - %v", i, err) + } + + tstz.DecodeText(nil, buf) + if err != nil { + t.Errorf("%d. DecodeText failed - %v", i, err) + } + + if !(tstz.Valid && tstz.Time.Equal(tt.expected)) { + t.Errorf("%d. EncodeText did not truncate nanoseconds", i) + } + } + + { + tstz := pgtype.Timestamptz{Time: tt.input, Valid: true} + buf, err := tstz.EncodeBinary(nil, nil) + if err != nil { + t.Errorf("%d. EncodeBinary failed - %v", i, err) + } + + tstz.DecodeBinary(nil, buf) + if err != nil { + t.Errorf("%d. DecodeBinary failed - %v", i, err) + } + + if !(tstz.Valid && tstz.Time.Equal(tt.expected)) { + t.Errorf("%d. EncodeBinary did not truncate nanoseconds", i) + } + } + } +} + +// https://github.com/jackc/pgtype/issues/74 +func TestTimestamptzDecodeTextInvalid(t *testing.T) { + tstz := &pgtype.Timestamptz{} + err := tstz.DecodeText(nil, []byte(`eeeee`)) + require.Error(t, err) +} + +func TestTimestamptzSet(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Timestamptz + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Valid: true}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), Valid: true}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: pgtype.Infinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true}}, + {source: pgtype.NegativeInfinity, result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, + } + + for i, tt := range successfulTests { + var r pgtype.Timestamptz + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestamptzAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Timestamptz + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {src: pgtype.Timestamptz{Time: time.Time{}}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Timestamptz + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Timestamptz + dst interface{} + }{ + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Valid: true}, dst: &tim}, + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Valid: true}, dst: &tim}, + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} + +func TestTimestamptzMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Timestamptz + result string + }{ + {source: pgtype.Timestamptz{}, result: "null"}, + {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29T10:05:45-06:00\""}, + {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}, result: "\"2012-03-29T10:05:45.555-06:00\""}, + {source: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true}, result: "\"infinity\""}, + {source: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "\"-infinity\""}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestTimestamptzUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Timestamptz + }{ + {source: "null", result: pgtype.Timestamptz{}}, + {source: "\"2012-03-29T10:05:45-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"2012-03-29T10:05:45.555-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Valid: true}}, + {source: "\"infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Valid: true}}, + {source: "\"-infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, + } + for i, tt := range successfulTests { + var r pgtype.Timestamptz + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !r.Time.Equal(tt.result.Time) || r.Valid != tt.result.Valid || r.InfinityModifier != tt.result.InfinityModifier { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/tsrange.go b/pgtype/tsrange.go new file mode 100644 index 00000000..7495d972 --- /dev/null +++ b/pgtype/tsrange.go @@ -0,0 +1,257 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgio" +) + +type Tsrange struct { + Lower Timestamp + Upper Timestamp + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (dst *Tsrange) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Tsrange{} + return nil + } + + switch value := src.(type) { + case Tsrange: + *dst = value + case *Tsrange: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return fmt.Errorf("cannot convert %v to Tsrange", src) + } + + return nil +} + +func (src Tsrange) Get() interface{} { + if !src.Valid { + return nil + } + return src +} + +func (src *Tsrange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Tsrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tsrange{} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Tsrange{Valid: true} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tsrange{} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Tsrange{Valid: true} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Tsrange) Scan(src interface{}) error { + if src == nil { + *dst = Tsrange{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Tsrange) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/tsrange_array.go b/pgtype/tsrange_array.go new file mode 100644 index 00000000..2af25f8d --- /dev/null +++ b/pgtype/tsrange_array.go @@ -0,0 +1,457 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type TsrangeArray struct { + Elements []Tsrange + Dimensions []ArrayDimension + Valid bool +} + +func (dst *TsrangeArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = TsrangeArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []Tsrange: + if value == nil { + *dst = TsrangeArray{} + } else if len(value) == 0 { + *dst = TsrangeArray{Valid: true} + } else { + *dst = TsrangeArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = TsrangeArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for TsrangeArray", src) + } + if elementsLength == 0 { + *dst = TsrangeArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to TsrangeArray", src) + } + + *dst = TsrangeArray{ + Elements: make([]Tsrange, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Tsrange, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to TsrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *TsrangeArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to TsrangeArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in TsrangeArray", err) + } + index++ + + return index, nil +} + +func (dst TsrangeArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *TsrangeArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]Tsrange: + *v = make([]Tsrange, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *TsrangeArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from TsrangeArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from TsrangeArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *TsrangeArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TsrangeArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Tsrange + + if len(uta.Elements) > 0 { + elements = make([]Tsrange, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Tsrange + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TsrangeArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *TsrangeArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TsrangeArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TsrangeArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Tsrange, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = TsrangeArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src TsrangeArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src TsrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("tsrange"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "tsrange") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TsrangeArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src TsrangeArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/tsrange_test.go b/pgtype/tsrange_test.go new file mode 100644 index 00000000..daea59bb --- /dev/null +++ b/pgtype/tsrange_test.go @@ -0,0 +1,41 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestTsrangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "tsrange", []interface{}{ + &pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + &pgtype.Tsrange{ + Lower: pgtype.Timestamp{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Timestamp{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + &pgtype.Tsrange{ + Lower: pgtype.Timestamp{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + &pgtype.Tsrange{}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Tsrange) + b := bb.(pgtype.Tsrange) + + return a.Valid == b.Valid && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Valid == b.Lower.Valid && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Valid == b.Upper.Valid && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} diff --git a/pgtype/tstzrange.go b/pgtype/tstzrange.go new file mode 100644 index 00000000..3d4e2cde --- /dev/null +++ b/pgtype/tstzrange.go @@ -0,0 +1,257 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgio" +) + +type Tstzrange struct { + Lower Timestamptz + Upper Timestamptz + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (dst *Tstzrange) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Tstzrange{} + return nil + } + + switch value := src.(type) { + case Tstzrange: + *dst = value + case *Tstzrange: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return fmt.Errorf("cannot convert %v to Tstzrange", src) + } + + return nil +} + +func (src Tstzrange) Get() interface{} { + if !src.Valid { + return nil + } + return src +} + +func (src *Tstzrange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Tstzrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tstzrange{} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Tstzrange{Valid: true} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tstzrange{} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Tstzrange{Valid: true} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Tstzrange) Scan(src interface{}) error { + if src == nil { + *dst = Tstzrange{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Tstzrange) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/tstzrange_array.go b/pgtype/tstzrange_array.go new file mode 100644 index 00000000..389d6b4c --- /dev/null +++ b/pgtype/tstzrange_array.go @@ -0,0 +1,457 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type TstzrangeArray struct { + Elements []Tstzrange + Dimensions []ArrayDimension + Valid bool +} + +func (dst *TstzrangeArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = TstzrangeArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []Tstzrange: + if value == nil { + *dst = TstzrangeArray{} + } else if len(value) == 0 { + *dst = TstzrangeArray{Valid: true} + } else { + *dst = TstzrangeArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = TstzrangeArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for TstzrangeArray", src) + } + if elementsLength == 0 { + *dst = TstzrangeArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to TstzrangeArray", src) + } + + *dst = TstzrangeArray{ + Elements: make([]Tstzrange, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Tstzrange, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to TstzrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *TstzrangeArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to TstzrangeArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in TstzrangeArray", err) + } + index++ + + return index, nil +} + +func (dst TstzrangeArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *TstzrangeArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]Tstzrange: + *v = make([]Tstzrange, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *TstzrangeArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from TstzrangeArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from TstzrangeArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *TstzrangeArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TstzrangeArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Tstzrange + + if len(uta.Elements) > 0 { + elements = make([]Tstzrange, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Tstzrange + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TstzrangeArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *TstzrangeArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TstzrangeArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TstzrangeArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Tstzrange, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = TstzrangeArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src TstzrangeArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src TstzrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("tstzrange"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "tstzrange") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TstzrangeArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src TstzrangeArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/tstzrange_test.go b/pgtype/tstzrange_test.go new file mode 100644 index 00000000..49cfc63e --- /dev/null +++ b/pgtype/tstzrange_test.go @@ -0,0 +1,49 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestTstzrangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "tstzrange", []interface{}{ + &pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + &pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Timestamptz{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + &pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + &pgtype.Tstzrange{}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Tstzrange) + b := bb.(pgtype.Tstzrange) + + return a.Valid == b.Valid && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Valid == b.Lower.Valid && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Valid == b.Upper.Valid && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} + +// https://github.com/jackc/pgtype/issues/74 +func TestTstzRangeDecodeTextInvalid(t *testing.T) { + tstzrange := &pgtype.Tstzrange{} + err := tstzrange.DecodeText(nil, []byte(`[eeee,)`)) + require.Error(t, err) +} diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb new file mode 100644 index 00000000..e1ead59c --- /dev/null +++ b/pgtype/typed_array.go.erb @@ -0,0 +1,481 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgio" +) + +type <%= pgtype_array_type %> struct { + Elements []<%= pgtype_element_type %> + Dimensions []ArrayDimension + Valid bool +} + +func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = <%= pgtype_array_type %>{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + <% go_array_types.split(",").each do |t| %> + <% if t != "[]#{pgtype_element_type}" %> + case <%= t %>: + if value == nil { + *dst = <%= pgtype_array_type %>{} + } else if len(value) == 0 { + *dst = <%= pgtype_array_type %>{Valid: true} + } else { + elements := make([]<%= pgtype_element_type %>, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = <%= pgtype_array_type %>{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + <% end %> + <% end %> + case []<%= pgtype_element_type %>: + if value == nil { + *dst = <%= pgtype_array_type %>{} + } else if len(value) == 0 { + *dst = <%= pgtype_array_type %>{Valid: true} + } else { + *dst = <%= pgtype_array_type %>{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = <%= pgtype_array_type %>{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for <%= pgtype_array_type %>", src) + } + if elementsLength == 0 { + *dst = <%= pgtype_array_type %>{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to <%= pgtype_array_type %>", src) + } + + *dst = <%= pgtype_array_type %> { + Elements: make([]<%= pgtype_element_type %>, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]<%= pgtype_element_type %>, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to <%= pgtype_array_type %>, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *<%= pgtype_array_type %>) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to <%= pgtype_array_type %>") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in <%= pgtype_array_type %>", err) + } + index++ + + return index, nil +} + +func (dst <%= pgtype_array_type %>) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1{ + // Attempt to match to select common types: + switch v := dst.(type) { + <% go_array_types.split(",").each do |t| %> + case *<%= t %>: + *v = make(<%= t %>, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + <% end %> + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr(){ + return 0, fmt.Errorf("cannot assign all values from <%= pgtype_array_type %>") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from <%= pgtype_array_type %>") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= pgtype_array_type %>{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []<%= pgtype_element_type %> + + if len(uta.Elements) > 0 { + elements = make([]<%= pgtype_element_type %>, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem <%= pgtype_element_type %> + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +<% if binary_format == "true" %> +func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= pgtype_array_type %>{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = <%= pgtype_array_type %>{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]<%= pgtype_element_type %>, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp:rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} +<% end %> + +func (src <%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `<%= text_null %>`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +<% if binary_format == "true" %> + func (src <%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil + } +<% end %> + +// Scan implements the database/sql Scanner interface. +func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src <%= pgtype_array_type %>) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh new file mode 100755 index 00000000..ea28be07 --- /dev/null +++ b/pgtype/typed_array_gen.sh @@ -0,0 +1,28 @@ +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time,[]*time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time,[]*time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go +erb pgtype_array_type=TsrangeArray pgtype_element_type=Tsrange go_array_types=[]Tsrange element_type_name=tsrange text_null=NULL binary_format=true typed_array.go.erb > tsrange_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time,[]*time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32,[]*float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64,[]*float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go +erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr go_array_types=[]net.HardwareAddr,[]*net.HardwareAddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go +erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string,[]*string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string,[]*string element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go +erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string,[]*string element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string,[]*string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go +erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]*float32,[]float64,[]*float64,[]int64,[]*int64,[]uint64,[]*uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go +erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string,[]*string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go +erb pgtype_array_type=JSONBArray pgtype_element_type=JSONB go_array_types=[]string,[][]byte element_type_name=jsonb text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go + +# While the binary format is theoretically possible it is only practical to use the text format. +erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string text_null=NULL binary_format=false typed_array.go.erb > enum_array.go + +goimports -w *_array.go diff --git a/pgtype/typed_range.go.erb b/pgtype/typed_range.go.erb new file mode 100644 index 00000000..99d8c22d --- /dev/null +++ b/pgtype/typed_range.go.erb @@ -0,0 +1,259 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgio" +) + +type <%= range_type %> struct { + Lower <%= element_type %> + Upper <%= element_type %> + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (dst *<%= range_type %>) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = <%= range_type %>{} + return nil + } + + switch value := src.(type) { + case <%= range_type %>: + *dst = value + case *<%= range_type %>: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return fmt.Errorf("cannot convert %v to <%= range_type %>", src) + } + + return nil +} + +func (src <%= range_type %>) Get() interface{} { + if !src.Valid { + return nil + } + return src +} + +func (src *<%= range_type %>) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *<%= range_type %>) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= range_type %>{} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = <%= range_type %>{Valid: true} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *<%= range_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= range_type %>{} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = <%= range_type %>{Valid: true} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *<%= range_type %>) Scan(src interface{}) error { + if src == nil { + *dst = <%= range_type %>{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src <%= range_type %>) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/typed_range_gen.sh b/pgtype/typed_range_gen.sh new file mode 100644 index 00000000..bedda292 --- /dev/null +++ b/pgtype/typed_range_gen.sh @@ -0,0 +1,7 @@ +erb range_type=Int4range element_type=Int4 typed_range.go.erb > int4range.go +erb range_type=Int8range element_type=Int8 typed_range.go.erb > int8range.go +erb range_type=Tsrange element_type=Timestamp typed_range.go.erb > tsrange.go +erb range_type=Tstzrange element_type=Timestamptz typed_range.go.erb > tstzrange.go +erb range_type=Daterange element_type=Date typed_range.go.erb > daterange.go +erb range_type=Numrange element_type=Numeric typed_range.go.erb > numrange.go +goimports -w *range.go diff --git a/pgtype/unknown.go b/pgtype/unknown.go new file mode 100644 index 00000000..0e576ee9 --- /dev/null +++ b/pgtype/unknown.go @@ -0,0 +1,44 @@ +package pgtype + +import "database/sql/driver" + +// Unknown represents the PostgreSQL unknown type. It is either a string literal +// or NULL. It is used when PostgreSQL does not know the type of a value. In +// general, this will only be used in pgx when selecting a null value without +// type information. e.g. SELECT NULL; +type Unknown struct { + String string + Valid bool +} + +func (dst *Unknown) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst Unknown) Get() interface{} { + return (Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Unknown is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *Unknown) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Unknown) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Unknown) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Unknown) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Unknown) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/pgtype/uuid.go b/pgtype/uuid.go new file mode 100644 index 00000000..4533aa06 --- /dev/null +++ b/pgtype/uuid.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/hex" + "fmt" +) + +type UUID struct { + Bytes [16]byte + Valid bool + + UUIDDecoderWrapper func(interface{}) UUIDDecoder + Getter func(UUID) interface{} +} + +func (n *UUID) NewTypeValue() Value { + return &UUID{ + UUIDDecoderWrapper: n.UUIDDecoderWrapper, + Getter: n.Getter, + } +} + +func (n *UUID) TypeName() string { + return "uuid" +} + +func (dst *UUID) setNil() { + dst.Bytes = [16]byte{} + dst.Valid = false +} + +func (dst *UUID) setByteArray(value [16]byte) { + dst.Bytes = value + dst.Valid = true +} + +func (dst *UUID) setByteSlice(value []byte) error { + if value != nil { + if len(value) != 16 { + return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + } + copy(dst.Bytes[:], value) + dst.Valid = true + } else { + dst.setNil() + } + + return nil +} + +func (dst *UUID) setString(value string) error { + uuid, err := parseUUID(value) + if err != nil { + return err + } + dst.setByteArray(uuid) + return nil +} + +func (dst *UUID) Set(src interface{}) error { + if src == nil { + dst.setNil() + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case [16]byte: + dst.setByteArray(value) + case []byte: + return dst.setByteSlice(value) + case string: + return dst.setString(value) + case *string: + if value == nil { + dst.setNil() + } else { + return dst.setString(*value) + } + default: + if originalSrc, ok := underlyingUUIDType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to UUID", value) + } + + return nil +} + +func (dst UUID) Get() interface{} { + if dst.Getter != nil { + return dst.Getter(dst) + } + + if !dst.Valid { + return nil + } + + return dst.Bytes +} + +type UUIDDecoder interface { + DecodeUUID(*UUID) error +} + +func (src *UUID) AssignTo(dst interface{}) error { + if d, ok := dst.(UUIDDecoder); ok { + return d.DecodeUUID(src) + } else { + if src.UUIDDecoderWrapper != nil { + d = src.UUIDDecoderWrapper(dst) + if d != nil { + return d.DecodeUUID(src) + } + } + } + + if !src.Valid { + return NullAssignTo(dst) + } + + switch v := dst.(type) { + case *[16]byte: + *v = src.Bytes + return nil + case *[]byte: + *v = make([]byte, 16) + copy(*v, src.Bytes[:]) + return nil + case *string: + *v = encodeUUID(src.Bytes) + return nil + default: + if nextDst, retry := GetAssignToDstType(v); retry { + return src.AssignTo(nextDst) + } + } + + return nil +} + +// parseUUID converts a string UUID in standard form to a byte array. +func parseUUID(src string) (dst [16]byte, err error) { + switch len(src) { + case 36: + src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:] + case 32: + // dashes already stripped, assume valid + default: + // assume invalid. + return dst, fmt.Errorf("cannot parse UUID %v", src) + } + + buf, err := hex.DecodeString(src) + if err != nil { + return dst, err + } + + copy(dst[:], buf) + return dst, err +} + +// encodeUUID converts a uuid byte array to UUID standard string form. +func encodeUUID(src [16]byte) string { + return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16]) +} + +func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + dst.setNil() + return nil + } + + if len(src) != 36 { + return fmt.Errorf("invalid length for UUID: %v", len(src)) + } + + buf, err := parseUUID(string(src)) + if err != nil { + return err + } + + dst.setByteArray(buf) + return nil +} + +func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + dst.setNil() + return nil + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for UUID: %v", len(src)) + } + + return dst.setByteSlice(src) +} + +func (src UUID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return append(buf, encodeUUID(src.Bytes)...), nil +} + +func (src UUID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + return append(buf, src.Bytes[:]...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *UUID) Scan(src interface{}) error { + if src == nil { + dst.setNil() + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src UUID) Value() (driver.Value, error) { + return EncodeValueText(src) +} + +func (src UUID) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + var buff bytes.Buffer + buff.WriteByte('"') + buff.WriteString(encodeUUID(src.Bytes)) + buff.WriteByte('"') + return buff.Bytes(), nil +} + +func (dst *UUID) UnmarshalJSON(src []byte) error { + if bytes.Compare(src, []byte("null")) == 0 { + return dst.Set(nil) + } + if len(src) != 38 { + return fmt.Errorf("invalid length for UUID: %v", len(src)) + } + return dst.Set(string(src[1 : len(src)-1])) +} diff --git a/pgtype/uuid_array.go b/pgtype/uuid_array.go new file mode 100644 index 00000000..98904f9f --- /dev/null +++ b/pgtype/uuid_array.go @@ -0,0 +1,560 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type UUIDArray struct { + Elements []UUID + Dimensions []ArrayDimension + Valid bool +} + +func (dst *UUIDArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = UUIDArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case [][16]byte: + if value == nil { + *dst = UUIDArray{} + } else if len(value) == 0 { + *dst = UUIDArray{Valid: true} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case [][]byte: + if value == nil { + *dst = UUIDArray{} + } else if len(value) == 0 { + *dst = UUIDArray{Valid: true} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []string: + if value == nil { + *dst = UUIDArray{} + } else if len(value) == 0 { + *dst = UUIDArray{Valid: true} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*string: + if value == nil { + *dst = UUIDArray{} + } else if len(value) == 0 { + *dst = UUIDArray{Valid: true} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []UUID: + if value == nil { + *dst = UUIDArray{} + } else if len(value) == 0 { + *dst = UUIDArray{Valid: true} + } else { + *dst = UUIDArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = UUIDArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for UUIDArray", src) + } + if elementsLength == 0 { + *dst = UUIDArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to UUIDArray", src) + } + + *dst = UUIDArray{ + Elements: make([]UUID, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]UUID, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to UUIDArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *UUIDArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to UUIDArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in UUIDArray", err) + } + index++ + + return index, nil +} + +func (dst UUIDArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *UUIDArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[][16]byte: + *v = make([][16]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[][]byte: + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *UUIDArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from UUIDArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from UUIDArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *UUIDArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = UUIDArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []UUID + + if len(uta.Elements) > 0 { + elements = make([]UUID, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem UUID + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = UUIDArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *UUIDArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = UUIDArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = UUIDArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]UUID, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = UUIDArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src UUIDArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("uuid"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "uuid") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *UUIDArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src UUIDArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/uuid_array_test.go b/pgtype/uuid_array_test.go new file mode 100644 index 00000000..47afadff --- /dev/null +++ b/pgtype/uuid_array_test.go @@ -0,0 +1,368 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestUUIDArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid[]", []interface{}{ + &pgtype.UUIDArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.UUIDArray{}, + &pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, + {}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestUUIDArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.UUIDArray + }{ + { + source: nil, + result: pgtype.UUIDArray{}, + }, + { + source: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][16]byte{}, + result: pgtype.UUIDArray{Valid: true}, + }, + { + source: ([][16]byte)(nil), + result: pgtype.UUIDArray{}, + }, + { + source: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][]byte{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, + nil, + {32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, + }, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, + {}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 4}}, + Valid: true}, + }, + { + source: [][]byte{}, + result: pgtype.UUIDArray{Valid: true}, + }, + { + source: ([][]byte)(nil), + result: pgtype.UUIDArray{}, + }, + { + source: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: []string{}, + result: pgtype.UUIDArray{Valid: true}, + }, + { + source: ([]string)(nil), + result: pgtype.UUIDArray{}, + }, + { + source: [][][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Valid: true}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Valid: true}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.UUIDArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestUUIDArrayAssignTo(t *testing.T) { + var byteArraySlice [][16]byte + var byteSliceSlice [][]byte + var stringSlice []string + var byteSlice []byte + var byteArraySliceDim2 [][][16]byte + var stringSliceDim4 [][][][]string + var byteArrayDim2 [2][1][16]byte + var stringArrayDim4 [2][1][1][3]string + + simpleTests := []struct { + src pgtype.UUIDArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &byteArraySlice, + expected: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + }, + { + src: pgtype.UUIDArray{}, + dst: &byteArraySlice, + expected: ([][16]byte)(nil), + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &byteSliceSlice, + expected: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + }, + { + src: pgtype.UUIDArray{}, + dst: &byteSliceSlice, + expected: ([][]byte)(nil), + }, + { + src: pgtype.UUIDArray{Valid: true}, + dst: &byteSlice, + expected: []byte{}, + }, + { + src: pgtype.UUIDArray{Valid: true}, + dst: &stringSlice, + expected: []string{}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &stringSlice, + expected: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, + }, + { + src: pgtype.UUIDArray{}, + dst: &stringSlice, + expected: ([]string)(nil), + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &byteArraySliceDim2, + expected: [][][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Valid: true}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &stringSliceDim4, + expected: [][][][]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &byteArrayDim2, + expected: [2][1][16]byte{{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + {{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}}}, + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Valid: true}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Valid: true}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Valid: true}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Valid: true}, + {Bytes: [16]byte{80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}, Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{ + {{{ + "00010203-0405-0607-0809-0a0b0c0d0e0f", + "10111213-1415-1617-1819-1a1b1c1d1e1f", + "20212223-2425-2627-2829-2a2b2c2d2e2f"}}}, + {{{ + "30313233-3435-3637-3839-3a3b3c3d3e3f", + "40414243-4445-4647-4849-4a4b4c4d4e4f", + "50515253-5455-5657-5859-5a5b5c5d5e5f"}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go new file mode 100644 index 00000000..63797178 --- /dev/null +++ b/pgtype/uuid_test.go @@ -0,0 +1,229 @@ +package pgtype_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestUUIDTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ + &pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + &pgtype.UUID{}, + }) +} + +type SomeUUIDWrapper struct { + SomeUUIDType +} + +type SomeUUIDType [16]byte + +func TestUUIDSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.UUID + }{ + { + source: nil, + result: pgtype.UUID{}, + }, + { + source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + }, + { + source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + }, + { + source: SomeUUIDType{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + }, + { + source: ([]byte)(nil), + result: pgtype.UUID{}, + }, + { + source: "00010203-0405-0607-0809-0a0b0c0d0e0f", + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + }, + { + source: "000102030405060708090a0b0c0d0e0f", + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.UUID + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r.Bytes != tt.result.Bytes || r.Valid != tt.result.Valid { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestUUIDAssignTo(t *testing.T) { + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} + var dst [16]byte + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} + var dst []byte + expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare(dst, expected) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} + var dst SomeUUIDType + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} + var dst string + expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} + var dst SomeUUIDWrapper + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst.SomeUUIDType != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } +} + +func TestUUID_MarshalJSON(t *testing.T) { + tests := []struct { + name string + src pgtype.UUID + want []byte + }{ + { + name: "first", + src: pgtype.UUID{ + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Valid: true, + }, + want: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), + }, + { + name: "third", + src: pgtype.UUID{ + Bytes: [16]byte{}, + }, + want: []byte("null"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.src.MarshalJSON() + require.NoError(t, err) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUUID_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + want *pgtype.UUID + src []byte + wantErr bool + }{ + { + name: "first", + want: &pgtype.UUID{ + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Valid: true, + }, + src: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), + wantErr: false, + }, + { + name: "second", + want: &pgtype.UUID{ + Bytes: [16]byte{}, + }, + src: []byte("null"), + wantErr: false, + }, + { + name: "third", + want: &pgtype.UUID{ + Bytes: [16]byte{}, + Valid: false, + }, + src: []byte("1d485a7a-6d18-4599-8c6c-34425616887a"), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &pgtype.UUID{} + if err := got.UnmarshalJSON(tt.src); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pgtype/varbit.go b/pgtype/varbit.go new file mode 100644 index 00000000..bc6fdac4 --- /dev/null +++ b/pgtype/varbit.go @@ -0,0 +1,123 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + + "github.com/jackc/pgio" +) + +type Varbit struct { + Bytes []byte + Len int32 // Number of bits + Valid bool +} + +func (dst *Varbit) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Varbit", src) +} + +func (dst Varbit) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *Varbit) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Varbit) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Varbit{} + return nil + } + + bitLen := len(src) + byteLen := bitLen / 8 + if bitLen%8 > 0 { + byteLen++ + } + buf := make([]byte, byteLen) + + for i, b := range src { + if b == '1' { + byteIdx := i / 8 + bitIdx := uint(i % 8) + buf[byteIdx] = buf[byteIdx] | (128 >> bitIdx) + } + } + + *dst = Varbit{Bytes: buf, Len: int32(bitLen), Valid: true} + return nil +} + +func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Varbit{} + return nil + } + + if len(src) < 4 { + return fmt.Errorf("invalid length for varbit: %v", len(src)) + } + + bitLen := int32(binary.BigEndian.Uint32(src)) + rp := 4 + + *dst = Varbit{Bytes: src[rp:], Len: bitLen, Valid: true} + return nil +} + +func (src Varbit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + for i := int32(0); i < src.Len; i++ { + byteIdx := i / 8 + bitMask := byte(128 >> byte(i%8)) + char := byte('0') + if src.Bytes[byteIdx]&bitMask > 0 { + char = '1' + } + buf = append(buf, char) + } + + return buf, nil +} + +func (src Varbit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + buf = pgio.AppendInt32(buf, src.Len) + return append(buf, src.Bytes...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Varbit) Scan(src interface{}) error { + if src == nil { + *dst = Varbit{} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Varbit) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/varbit_test.go b/pgtype/varbit_test.go new file mode 100644 index 00000000..b81bdc0e --- /dev/null +++ b/pgtype/varbit_test.go @@ -0,0 +1,26 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestVarbitTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "varbit", []interface{}{ + &pgtype.Varbit{Bytes: []byte{}, Len: 0, Valid: true}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Valid: true}, + &pgtype.Varbit{}, + }) +} + +func TestVarbitNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select B'111111111'", + Value: &pgtype.Varbit{Bytes: []byte{255, 128}, Len: 9, Valid: true}, + }, + }) +} diff --git a/pgtype/varchar.go b/pgtype/varchar.go new file mode 100644 index 00000000..fea31d18 --- /dev/null +++ b/pgtype/varchar.go @@ -0,0 +1,66 @@ +package pgtype + +import ( + "database/sql/driver" +) + +type Varchar Text + +// Set converts from src to dst. Note that as Varchar is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *Varchar) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst Varchar) Get() interface{} { + return (Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Varchar is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *Varchar) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (Varchar) PreferredResultFormat() int16 { + return TextFormatCode +} + +func (dst *Varchar) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +func (Varchar) PreferredParamFormat() int16 { + return TextFormatCode +} + +func (src Varchar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeText(ci, buf) +} + +func (src Varchar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Varchar) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Varchar) Value() (driver.Value, error) { + return (Text)(src).Value() +} + +func (src Varchar) MarshalJSON() ([]byte, error) { + return (Text)(src).MarshalJSON() +} + +func (dst *Varchar) UnmarshalJSON(b []byte) error { + return (*Text)(dst).UnmarshalJSON(b) +} diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go new file mode 100644 index 00000000..3e0913dc --- /dev/null +++ b/pgtype/varchar_array.go @@ -0,0 +1,504 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type VarcharArray struct { + Elements []Varchar + Dimensions []ArrayDimension + Valid bool +} + +func (dst *VarcharArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = VarcharArray{} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []string: + if value == nil { + *dst = VarcharArray{} + } else if len(value) == 0 { + *dst = VarcharArray{Valid: true} + } else { + elements := make([]Varchar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = VarcharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []*string: + if value == nil { + *dst = VarcharArray{} + } else if len(value) == 0 { + *dst = VarcharArray{Valid: true} + } else { + elements := make([]Varchar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = VarcharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Valid: true, + } + } + + case []Varchar: + if value == nil { + *dst = VarcharArray{} + } else if len(value) == 0 { + *dst = VarcharArray{Valid: true} + } else { + *dst = VarcharArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Valid: true, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = VarcharArray{} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for VarcharArray", src) + } + if elementsLength == 0 { + *dst = VarcharArray{Valid: true} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to VarcharArray", src) + } + + *dst = VarcharArray{ + Elements: make([]Varchar, elementsLength), + Dimensions: dimensions, + Valid: true, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Varchar, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to VarcharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *VarcharArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to VarcharArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in VarcharArray", err) + } + index++ + + return index, nil +} + +func (dst VarcharArray) Get() interface{} { + if !dst.Valid { + return nil + } + return dst +} + +func (src *VarcharArray) AssignTo(dst interface{}) error { + if !src.Valid { + return NullAssignTo(dst) + } + + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil +} + +func (src *VarcharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from VarcharArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from VarcharArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Varchar + + if len(uta.Elements) > 0 { + elements = make([]Varchar, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Varchar + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} + + return nil +} + +func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = VarcharArray{Dimensions: arrayHeader.Dimensions, Valid: true} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Varchar, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} + return nil +} + +func (src VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + if !src.Valid { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("varchar"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "varchar") + } + + for i := range src.Elements { + if !src.Elements[i].Valid { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *VarcharArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src VarcharArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/varchar_array_test.go b/pgtype/varchar_array_test.go new file mode 100644 index 00000000..cf0efd6d --- /dev/null +++ b/pgtype/varchar_array_test.go @@ -0,0 +1,282 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestVarcharArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "varchar[]", []interface{}{ + &pgtype.VarcharArray{ + Elements: nil, + Dimensions: nil, + Valid: true, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Valid: true}, + {}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.VarcharArray{}, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "bar ", Valid: true}, + {String: "NuLL", Valid: true}, + {String: `wow"quz\`, Valid: true}, + {String: "", Valid: true}, + {}, + {String: "null", Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Valid: true, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "quz", Valid: true}, + {String: "foo", Valid: true}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Valid: true, + }, + }) +} + +func TestVarcharArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.VarcharArray + }{ + { + source: []string{"foo"}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: (([]string)(nil)), + result: pgtype.VarcharArray{}, + }, + { + source: [][]string{{"foo"}, {"bar"}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + { + source: [2][1]string{{"foo"}, {"bar"}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + }, + { + source: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.VarcharArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestVarcharArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + var stringSliceDim2 [][]string + var stringSliceDim4 [][][][]string + var stringArrayDim2 [2][1]string + var stringArrayDim4 [2][1][1][3]string + + simpleTests := []struct { + src pgtype.VarcharArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.VarcharArray{}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + { + src: pgtype.VarcharArray{Valid: true}, + dst: &stringSlice, + expected: []string{}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &stringSliceDim2, + expected: [][]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &stringSliceDim4, + expected: [][][][]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &stringArrayDim2, + expected: [2][1]string{{"foo"}, {"bar"}}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + {String: "foo", Valid: true}, + {String: "bar", Valid: true}, + {String: "baz", Valid: true}, + {String: "wibble", Valid: true}, + {String: "wobble", Valid: true}, + {String: "wubble", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{ + {LowerBound: 1, Length: 2}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 1}, + {LowerBound: 1, Length: 3}}, + Valid: true}, + dst: &stringArrayDim4, + expected: [2][1][1][3]string{{{{"foo", "bar", "baz"}}}, {{{"wibble", "wobble", "wubble"}}}}, + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.VarcharArray + dst interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Valid: true, + }, + dst: &stringSlice, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &stringArrayDim2, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, + Valid: true}, + dst: &stringSlice, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Valid: true}, {String: "bar", Valid: true}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, + Valid: true}, + dst: &stringArrayDim4, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/workflows/ci.yml b/pgtype/workflows/ci.yml new file mode 100644 index 00000000..4b5a72f2 --- /dev/null +++ b/pgtype/workflows/ci.yml @@ -0,0 +1,52 @@ +name: CI + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + + test: + name: Test + runs-on: ubuntu-latest + + services: + postgres: + image: postgres + env: + POSTGRES_PASSWORD: secret + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + + steps: + + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ^1.13 + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: Create hstore extension + run: psql -c 'create extension hstore' + env: + PGHOST: localhost + PGUSER: postgres + PGPASSWORD: secret + PGSSLMODE: disable + + - name: Test + run: go test -v ./... + env: + PGHOST: localhost + PGUSER: postgres + PGPASSWORD: secret + PGSSLMODE: disable diff --git a/pgtype/xid.go b/pgtype/xid.go new file mode 100644 index 00000000..f6d6b22d --- /dev/null +++ b/pgtype/xid.go @@ -0,0 +1,64 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// XID is PostgreSQL's Transaction ID type. +// +// In later versions of PostgreSQL, it is the type used for the backend_xid +// and backend_xmin columns of the pg_stat_activity system view. +// +// Also, when one does +// +// select xmin, xmax, * from some_table; +// +// it is the data type of the xmin and xmax hidden system columns. +// +// It is currently implemented as an unsigned four byte integer. +// Its definition can be found in src/include/postgres_ext.h as TransactionId +// in the PostgreSQL sources. +type XID pguint32 + +// Set converts from src to dst. Note that as XID is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *XID) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst XID) Get() interface{} { + return (pguint32)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as XID is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *XID) AssignTo(dst interface{}) error { + return (*pguint32)(src).AssignTo(dst) +} + +func (dst *XID) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) +} + +func (dst *XID) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) +} + +func (src XID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeText(ci, buf) +} + +func (src XID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *XID) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src XID) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/pgtype/xid_test.go b/pgtype/xid_test.go new file mode 100644 index 00000000..fab10f79 --- /dev/null +++ b/pgtype/xid_test.go @@ -0,0 +1,102 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" +) + +func TestXIDTranscode(t *testing.T) { + pgTypeName := "xid" + values := []interface{}{ + &pgtype.XID{Uint: 42, Valid: true}, + &pgtype.XID{}, + } + eqFunc := func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + } + + testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, "github.com/jackc/pgx/stdlib", pgTypeName, values, eqFunc) +} + +func TestXIDSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.XID + }{ + {source: uint32(1), result: pgtype.XID{Uint: 1, Valid: true}}, + } + + for i, tt := range successfulTests { + var r pgtype.XID + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestXIDAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.XID + dst interface{} + expected interface{} + }{ + {src: pgtype.XID{Uint: 42, Valid: true}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.XID{}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.XID + dst interface{} + expected interface{} + }{ + {src: pgtype.XID{Uint: 42, Valid: true}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.XID + dst interface{} + }{ + {src: pgtype.XID{}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/zeronull/doc.go b/pgtype/zeronull/doc.go new file mode 100644 index 00000000..78a52307 --- /dev/null +++ b/pgtype/zeronull/doc.go @@ -0,0 +1,22 @@ +// Package zeronull contains types that automatically convert between database NULLs and Go zero values. +/* +Sometimes the distinction between a zero value and a NULL value is not useful at the application level. For example, +in PostgreSQL an empty string may be stored as NULL. There is usually no application level distinction between an +empty string and a NULL string. Package zeronull implements types that seamlessly convert between PostgreSQL NULL and +the zero value. + +It is recommended to convert types at usage time rather than instantiate these types directly. In the example below, +middlename would be stored as a NULL. + + firstname := "John" + middlename := "" + lastname := "Smith" + _, err := conn.Exec( + ctx, + "insert into people(firstname, middlename, lastname) values($1, $2, $3)", + zeronull.Text(firstname), + zeronull.Text(middlename), + zeronull.Text(lastname), + ) +*/ +package zeronull diff --git a/pgtype/zeronull/float8.go b/pgtype/zeronull/float8.go new file mode 100644 index 00000000..07d5e1a5 --- /dev/null +++ b/pgtype/zeronull/float8.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type Float8 float64 + +func (dst *Float8) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Float8 + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = Float8(nullable.Float) + } else { + *dst = 0 + } + + return nil +} + +func (dst *Float8) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Float8 + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = Float8(nullable.Float) + } else { + *dst = 0 + } + + return nil +} + +func (src Float8) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Float8{ + Float: float64(src), + Valid: true, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Float8) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Float8{ + Float: float64(src), + Valid: true, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Float8) Scan(src interface{}) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Float8 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Float8(nullable.Float) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float8) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/pgtype/zeronull/float8_test.go b/pgtype/zeronull/float8_test.go new file mode 100644 index 00000000..27fb785e --- /dev/null +++ b/pgtype/zeronull/float8_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestFloat8Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "float8", []interface{}{ + (zeronull.Float8)(1), + (zeronull.Float8)(0), + }) +} + +func TestFloat8ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "float8", (zeronull.Float8)(0)) +} + +func TestFloat8ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "float8", (zeronull.Float8)(0)) +} diff --git a/pgtype/zeronull/int2.go b/pgtype/zeronull/int2.go new file mode 100644 index 00000000..b3f9c328 --- /dev/null +++ b/pgtype/zeronull/int2.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type Int2 int16 + +func (dst *Int2) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int2 + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = Int2(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (dst *Int2) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int2 + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = Int2(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (src Int2) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int2{ + Int: int16(src), + Valid: true, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Int2) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int2{ + Int: int16(src), + Valid: true, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int2) Scan(src interface{}) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int2 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int2(nullable.Int) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int2) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/pgtype/zeronull/int2_test.go b/pgtype/zeronull/int2_test.go new file mode 100644 index 00000000..2dcb4e79 --- /dev/null +++ b/pgtype/zeronull/int2_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestInt2Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ + (zeronull.Int2)(1), + (zeronull.Int2)(0), + }) +} + +func TestInt2ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int2", (zeronull.Int2)(0)) +} + +func TestInt2ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int2", (zeronull.Int2)(0)) +} diff --git a/pgtype/zeronull/int4.go b/pgtype/zeronull/int4.go new file mode 100644 index 00000000..3efca4e6 --- /dev/null +++ b/pgtype/zeronull/int4.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type Int4 int32 + +func (dst *Int4) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int4 + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = Int4(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (dst *Int4) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int4 + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = Int4(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (src Int4) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int4{ + Int: int32(src), + Valid: true, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Int4) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int4{ + Int: int32(src), + Valid: true, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4) Scan(src interface{}) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int4 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int4(nullable.Int) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/pgtype/zeronull/int4_test.go b/pgtype/zeronull/int4_test.go new file mode 100644 index 00000000..309e4125 --- /dev/null +++ b/pgtype/zeronull/int4_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestInt4Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ + (zeronull.Int4)(1), + (zeronull.Int4)(0), + }) +} + +func TestInt4ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int4", (zeronull.Int4)(0)) +} + +func TestInt4ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int4", (zeronull.Int4)(0)) +} diff --git a/pgtype/zeronull/int8.go b/pgtype/zeronull/int8.go new file mode 100644 index 00000000..5cb063d8 --- /dev/null +++ b/pgtype/zeronull/int8.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type Int8 int64 + +func (dst *Int8) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int8 + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = Int8(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (dst *Int8) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int8 + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = Int8(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (src Int8) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int8{ + Int: int64(src), + Valid: true, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Int8) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int8{ + Int: int64(src), + Valid: true, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8) Scan(src interface{}) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int8 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int8(nullable.Int) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/pgtype/zeronull/int8_test.go b/pgtype/zeronull/int8_test.go new file mode 100644 index 00000000..ae80bc0a --- /dev/null +++ b/pgtype/zeronull/int8_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestInt8Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ + (zeronull.Int8)(1), + (zeronull.Int8)(0), + }) +} + +func TestInt8ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int8", (zeronull.Int8)(0)) +} + +func TestInt8ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int8", (zeronull.Int8)(0)) +} diff --git a/pgtype/zeronull/text.go b/pgtype/zeronull/text.go new file mode 100644 index 00000000..afcb1a42 --- /dev/null +++ b/pgtype/zeronull/text.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type Text string + +func (dst *Text) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Text + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = Text(nullable.String) + } else { + *dst = Text("") + } + + return nil +} + +func (dst *Text) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Text + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = Text(nullable.String) + } else { + *dst = Text("") + } + + return nil +} + +func (src Text) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == Text("") { + return nil, nil + } + + nullable := pgtype.Text{ + String: string(src), + Valid: true, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Text) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == Text("") { + return nil, nil + } + + nullable := pgtype.Text{ + String: string(src), + Valid: true, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Text) Scan(src interface{}) error { + if src == nil { + *dst = Text("") + return nil + } + + var nullable pgtype.Text + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Text(nullable.String) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Text) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/pgtype/zeronull/text_test.go b/pgtype/zeronull/text_test.go new file mode 100644 index 00000000..f08a0d2a --- /dev/null +++ b/pgtype/zeronull/text_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestTextTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "text", []interface{}{ + (zeronull.Text)("foo"), + (zeronull.Text)(""), + }) +} + +func TestTextConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "text", (zeronull.Text)("")) +} + +func TestTextConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "text", (zeronull.Text)("")) +} diff --git a/pgtype/zeronull/timestamp.go b/pgtype/zeronull/timestamp.go new file mode 100644 index 00000000..61787818 --- /dev/null +++ b/pgtype/zeronull/timestamp.go @@ -0,0 +1,91 @@ +package zeronull + +import ( + "database/sql/driver" + "time" + + "github.com/jackc/pgtype" +) + +type Timestamp time.Time + +func (dst *Timestamp) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Timestamp + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = Timestamp(nullable.Time) + } else { + *dst = Timestamp{} + } + + return nil +} + +func (dst *Timestamp) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Timestamp + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = Timestamp(nullable.Time) + } else { + *dst = Timestamp{} + } + + return nil +} + +func (src Timestamp) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == Timestamp{}) { + return nil, nil + } + + nullable := pgtype.Timestamp{ + Time: time.Time(src), + Valid: true, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Timestamp) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == Timestamp{}) { + return nil, nil + } + + nullable := pgtype.Timestamp{ + Time: time.Time(src), + Valid: true, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamp) Scan(src interface{}) error { + if src == nil { + *dst = Timestamp{} + return nil + } + + var nullable pgtype.Timestamp + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Timestamp(nullable.Time) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamp) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/pgtype/zeronull/timestamp_test.go b/pgtype/zeronull/timestamp_test.go new file mode 100644 index 00000000..ec96ff07 --- /dev/null +++ b/pgtype/zeronull/timestamp_test.go @@ -0,0 +1,29 @@ +package zeronull_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestTimestampTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ + (zeronull.Timestamp)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + (zeronull.Timestamp)(time.Time{}), + }, func(a, b interface{}) bool { + at := a.(zeronull.Timestamp) + bt := b.(zeronull.Timestamp) + + return time.Time(at).Equal(time.Time(bt)) + }) +} + +func TestTimestampConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "timestamp", (zeronull.Timestamp)(time.Time{})) +} + +func TestTimestampConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "timestamp", (zeronull.Timestamp)(time.Time{})) +} diff --git a/pgtype/zeronull/timestamptz.go b/pgtype/zeronull/timestamptz.go new file mode 100644 index 00000000..4896e9b7 --- /dev/null +++ b/pgtype/zeronull/timestamptz.go @@ -0,0 +1,91 @@ +package zeronull + +import ( + "database/sql/driver" + "time" + + "github.com/jackc/pgtype" +) + +type Timestamptz time.Time + +func (dst *Timestamptz) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Timestamptz + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = Timestamptz(nullable.Time) + } else { + *dst = Timestamptz{} + } + + return nil +} + +func (dst *Timestamptz) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Timestamptz + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = Timestamptz(nullable.Time) + } else { + *dst = Timestamptz{} + } + + return nil +} + +func (src Timestamptz) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == Timestamptz{}) { + return nil, nil + } + + nullable := pgtype.Timestamptz{ + Time: time.Time(src), + Valid: true, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Timestamptz) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == Timestamptz{}) { + return nil, nil + } + + nullable := pgtype.Timestamptz{ + Time: time.Time(src), + Valid: true, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamptz) Scan(src interface{}) error { + if src == nil { + *dst = Timestamptz{} + return nil + } + + var nullable pgtype.Timestamptz + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Timestamptz(nullable.Time) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamptz) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/pgtype/zeronull/timestamptz_test.go b/pgtype/zeronull/timestamptz_test.go new file mode 100644 index 00000000..3a401c49 --- /dev/null +++ b/pgtype/zeronull/timestamptz_test.go @@ -0,0 +1,29 @@ +package zeronull_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestTimestamptzTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ + (zeronull.Timestamptz)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + (zeronull.Timestamptz)(time.Time{}), + }, func(a, b interface{}) bool { + at := a.(zeronull.Timestamptz) + bt := b.(zeronull.Timestamptz) + + return time.Time(at).Equal(time.Time(bt)) + }) +} + +func TestTimestamptzConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "timestamptz", (zeronull.Timestamptz)(time.Time{})) +} + +func TestTimestamptzConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "timestamptz", (zeronull.Timestamptz)(time.Time{})) +} diff --git a/pgtype/zeronull/uuid.go b/pgtype/zeronull/uuid.go new file mode 100644 index 00000000..25211122 --- /dev/null +++ b/pgtype/zeronull/uuid.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type UUID [16]byte + +func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.UUID + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = UUID(nullable.Bytes) + } else { + *dst = UUID{} + } + + return nil +} + +func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.UUID + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Valid { + *dst = UUID(nullable.Bytes) + } else { + *dst = UUID{} + } + + return nil +} + +func (src UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == UUID{}) { + return nil, nil + } + + nullable := pgtype.UUID{ + Bytes: [16]byte(src), + Valid: true, + } + + return nullable.EncodeText(ci, buf) +} + +func (src UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == UUID{}) { + return nil, nil + } + + nullable := pgtype.UUID{ + Bytes: [16]byte(src), + Valid: true, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *UUID) Scan(src interface{}) error { + if src == nil { + *dst = UUID{} + return nil + } + + var nullable pgtype.UUID + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = UUID(nullable.Bytes) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src UUID) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/pgtype/zeronull/uuid_test.go b/pgtype/zeronull/uuid_test.go new file mode 100644 index 00000000..162bdf1f --- /dev/null +++ b/pgtype/zeronull/uuid_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestUUIDTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ + (*zeronull.UUID)(&[16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), + (*zeronull.UUID)(&[16]byte{}), + }) +} + +func TestUUIDConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "uuid", (*zeronull.UUID)(&[16]byte{})) +} + +func TestUUIDConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "uuid", (*zeronull.UUID)(&[16]byte{})) +} diff --git a/pgtype/zzz.aclitem.go b/pgtype/zzz.aclitem.go new file mode 100644 index 00000000..6ac1f94a --- /dev/null +++ b/pgtype/zzz.aclitem.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (ACLItem) BinaryFormatSupported() bool { + return true +} + +func (ACLItem) TextFormatSupported() bool { + return true +} + +func (ACLItem) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *ACLItem) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return fmt.Errorf("binary format not supported for %T", dst) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src ACLItem) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return nil, fmt.Errorf("binary format not supported for %T", src) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.bit.go b/pgtype/zzz.bit.go new file mode 100644 index 00000000..e95df74d --- /dev/null +++ b/pgtype/zzz.bit.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Bit) BinaryFormatSupported() bool { + return true +} + +func (Bit) TextFormatSupported() bool { + return true +} + +func (Bit) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Bit) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Bit) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.bool.go b/pgtype/zzz.bool.go new file mode 100644 index 00000000..e6ed52de --- /dev/null +++ b/pgtype/zzz.bool.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Bool) BinaryFormatSupported() bool { + return true +} + +func (Bool) TextFormatSupported() bool { + return true +} + +func (Bool) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Bool) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Bool) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.box.go b/pgtype/zzz.box.go new file mode 100644 index 00000000..5ca2df43 --- /dev/null +++ b/pgtype/zzz.box.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Box) BinaryFormatSupported() bool { + return true +} + +func (Box) TextFormatSupported() bool { + return true +} + +func (Box) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Box) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Box) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.bpchar.go b/pgtype/zzz.bpchar.go new file mode 100644 index 00000000..c3178670 --- /dev/null +++ b/pgtype/zzz.bpchar.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (BPChar) BinaryFormatSupported() bool { + return true +} + +func (BPChar) TextFormatSupported() bool { + return true +} + +func (BPChar) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *BPChar) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src BPChar) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.bytea.go b/pgtype/zzz.bytea.go new file mode 100644 index 00000000..4da5ad4f --- /dev/null +++ b/pgtype/zzz.bytea.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Bytea) BinaryFormatSupported() bool { + return true +} + +func (Bytea) TextFormatSupported() bool { + return true +} + +func (Bytea) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Bytea) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Bytea) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.cid.go b/pgtype/zzz.cid.go new file mode 100644 index 00000000..4cb9671d --- /dev/null +++ b/pgtype/zzz.cid.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (CID) BinaryFormatSupported() bool { + return true +} + +func (CID) TextFormatSupported() bool { + return true +} + +func (CID) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *CID) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src CID) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.cidr.go b/pgtype/zzz.cidr.go new file mode 100644 index 00000000..714908e0 --- /dev/null +++ b/pgtype/zzz.cidr.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (CIDR) BinaryFormatSupported() bool { + return true +} + +func (CIDR) TextFormatSupported() bool { + return true +} + +func (CIDR) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *CIDR) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src CIDR) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.circle.go b/pgtype/zzz.circle.go new file mode 100644 index 00000000..b111c06d --- /dev/null +++ b/pgtype/zzz.circle.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Circle) BinaryFormatSupported() bool { + return true +} + +func (Circle) TextFormatSupported() bool { + return true +} + +func (Circle) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Circle) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Circle) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.date.go b/pgtype/zzz.date.go new file mode 100644 index 00000000..66132082 --- /dev/null +++ b/pgtype/zzz.date.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Date) BinaryFormatSupported() bool { + return true +} + +func (Date) TextFormatSupported() bool { + return true +} + +func (Date) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Date) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Date) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.float4.go b/pgtype/zzz.float4.go new file mode 100644 index 00000000..b600805e --- /dev/null +++ b/pgtype/zzz.float4.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Float4) BinaryFormatSupported() bool { + return true +} + +func (Float4) TextFormatSupported() bool { + return true +} + +func (Float4) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Float4) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Float4) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.float8.go b/pgtype/zzz.float8.go new file mode 100644 index 00000000..dd3ba0fa --- /dev/null +++ b/pgtype/zzz.float8.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Float8) BinaryFormatSupported() bool { + return true +} + +func (Float8) TextFormatSupported() bool { + return true +} + +func (Float8) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Float8) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Float8) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.generic_binary.go b/pgtype/zzz.generic_binary.go new file mode 100644 index 00000000..b50f1f45 --- /dev/null +++ b/pgtype/zzz.generic_binary.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (GenericBinary) BinaryFormatSupported() bool { + return true +} + +func (GenericBinary) TextFormatSupported() bool { + return true +} + +func (GenericBinary) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *GenericBinary) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return fmt.Errorf("text format not supported for %T", dst) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src GenericBinary) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return nil, fmt.Errorf("text format not supported for %T", src) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.generic_text.go b/pgtype/zzz.generic_text.go new file mode 100644 index 00000000..5ab771cf --- /dev/null +++ b/pgtype/zzz.generic_text.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (GenericText) BinaryFormatSupported() bool { + return true +} + +func (GenericText) TextFormatSupported() bool { + return true +} + +func (GenericText) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *GenericText) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return fmt.Errorf("binary format not supported for %T", dst) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src GenericText) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return nil, fmt.Errorf("binary format not supported for %T", src) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.hstore.go b/pgtype/zzz.hstore.go new file mode 100644 index 00000000..ebd7bdee --- /dev/null +++ b/pgtype/zzz.hstore.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Hstore) BinaryFormatSupported() bool { + return true +} + +func (Hstore) TextFormatSupported() bool { + return true +} + +func (Hstore) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Hstore) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Hstore) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.inet.go b/pgtype/zzz.inet.go new file mode 100644 index 00000000..51daeee6 --- /dev/null +++ b/pgtype/zzz.inet.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Inet) BinaryFormatSupported() bool { + return true +} + +func (Inet) TextFormatSupported() bool { + return true +} + +func (Inet) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Inet) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Inet) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.int2.go b/pgtype/zzz.int2.go new file mode 100644 index 00000000..f2d959f9 --- /dev/null +++ b/pgtype/zzz.int2.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Int2) BinaryFormatSupported() bool { + return true +} + +func (Int2) TextFormatSupported() bool { + return true +} + +func (Int2) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Int2) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Int2) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.int4.go b/pgtype/zzz.int4.go new file mode 100644 index 00000000..bd7f9bda --- /dev/null +++ b/pgtype/zzz.int4.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Int4) BinaryFormatSupported() bool { + return true +} + +func (Int4) TextFormatSupported() bool { + return true +} + +func (Int4) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Int4) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Int4) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.int8.go b/pgtype/zzz.int8.go new file mode 100644 index 00000000..d6e98262 --- /dev/null +++ b/pgtype/zzz.int8.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Int8) BinaryFormatSupported() bool { + return true +} + +func (Int8) TextFormatSupported() bool { + return true +} + +func (Int8) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Int8) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Int8) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.interval.go b/pgtype/zzz.interval.go new file mode 100644 index 00000000..a34f2d59 --- /dev/null +++ b/pgtype/zzz.interval.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Interval) BinaryFormatSupported() bool { + return true +} + +func (Interval) TextFormatSupported() bool { + return true +} + +func (Interval) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Interval) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Interval) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.json.go b/pgtype/zzz.json.go new file mode 100644 index 00000000..40a736c9 --- /dev/null +++ b/pgtype/zzz.json.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (JSON) BinaryFormatSupported() bool { + return true +} + +func (JSON) TextFormatSupported() bool { + return true +} + +func (JSON) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *JSON) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src JSON) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.jsonb.go b/pgtype/zzz.jsonb.go new file mode 100644 index 00000000..a07934b7 --- /dev/null +++ b/pgtype/zzz.jsonb.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (JSONB) BinaryFormatSupported() bool { + return true +} + +func (JSONB) TextFormatSupported() bool { + return true +} + +func (JSONB) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *JSONB) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src JSONB) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.line.go b/pgtype/zzz.line.go new file mode 100644 index 00000000..7365744b --- /dev/null +++ b/pgtype/zzz.line.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Line) BinaryFormatSupported() bool { + return true +} + +func (Line) TextFormatSupported() bool { + return true +} + +func (Line) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Line) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Line) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.lseg.go b/pgtype/zzz.lseg.go new file mode 100644 index 00000000..1a95af09 --- /dev/null +++ b/pgtype/zzz.lseg.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Lseg) BinaryFormatSupported() bool { + return true +} + +func (Lseg) TextFormatSupported() bool { + return true +} + +func (Lseg) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Lseg) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Lseg) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.macadder.go b/pgtype/zzz.macadder.go new file mode 100644 index 00000000..5758d68f --- /dev/null +++ b/pgtype/zzz.macadder.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Macaddr) BinaryFormatSupported() bool { + return true +} + +func (Macaddr) TextFormatSupported() bool { + return true +} + +func (Macaddr) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Macaddr) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Macaddr) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.name.go b/pgtype/zzz.name.go new file mode 100644 index 00000000..6949c337 --- /dev/null +++ b/pgtype/zzz.name.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Name) BinaryFormatSupported() bool { + return true +} + +func (Name) TextFormatSupported() bool { + return true +} + +func (Name) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Name) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Name) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.numeric.go b/pgtype/zzz.numeric.go new file mode 100644 index 00000000..838bed40 --- /dev/null +++ b/pgtype/zzz.numeric.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Numeric) BinaryFormatSupported() bool { + return true +} + +func (Numeric) TextFormatSupported() bool { + return true +} + +func (Numeric) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Numeric) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Numeric) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.oid.go b/pgtype/zzz.oid.go new file mode 100644 index 00000000..bc3ba7d2 --- /dev/null +++ b/pgtype/zzz.oid.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (OID) BinaryFormatSupported() bool { + return true +} + +func (OID) TextFormatSupported() bool { + return true +} + +func (OID) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *OID) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src OID) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.oid_value.go b/pgtype/zzz.oid_value.go new file mode 100644 index 00000000..6fba9e44 --- /dev/null +++ b/pgtype/zzz.oid_value.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (OIDValue) BinaryFormatSupported() bool { + return true +} + +func (OIDValue) TextFormatSupported() bool { + return true +} + +func (OIDValue) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *OIDValue) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src OIDValue) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.path.go b/pgtype/zzz.path.go new file mode 100644 index 00000000..d761ac40 --- /dev/null +++ b/pgtype/zzz.path.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Path) BinaryFormatSupported() bool { + return true +} + +func (Path) TextFormatSupported() bool { + return true +} + +func (Path) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Path) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Path) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.pguint32.go b/pgtype/zzz.pguint32.go new file mode 100644 index 00000000..c869da8f --- /dev/null +++ b/pgtype/zzz.pguint32.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (pguint32) BinaryFormatSupported() bool { + return true +} + +func (pguint32) TextFormatSupported() bool { + return true +} + +func (pguint32) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *pguint32) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src pguint32) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.point.go b/pgtype/zzz.point.go new file mode 100644 index 00000000..083ded95 --- /dev/null +++ b/pgtype/zzz.point.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Point) BinaryFormatSupported() bool { + return true +} + +func (Point) TextFormatSupported() bool { + return true +} + +func (Point) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Point) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Point) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.polygon.go b/pgtype/zzz.polygon.go new file mode 100644 index 00000000..2bfdbbd4 --- /dev/null +++ b/pgtype/zzz.polygon.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Polygon) BinaryFormatSupported() bool { + return true +} + +func (Polygon) TextFormatSupported() bool { + return true +} + +func (Polygon) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Polygon) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Polygon) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.qchar.go b/pgtype/zzz.qchar.go new file mode 100644 index 00000000..adc0f462 --- /dev/null +++ b/pgtype/zzz.qchar.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (QChar) BinaryFormatSupported() bool { + return true +} + +func (QChar) TextFormatSupported() bool { + return true +} + +func (QChar) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *QChar) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return fmt.Errorf("text format not supported for %T", dst) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src QChar) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return nil, fmt.Errorf("text format not supported for %T", src) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.text.go b/pgtype/zzz.text.go new file mode 100644 index 00000000..e1a3908f --- /dev/null +++ b/pgtype/zzz.text.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Text) BinaryFormatSupported() bool { + return true +} + +func (Text) TextFormatSupported() bool { + return true +} + +func (Text) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *Text) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Text) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.tid.go b/pgtype/zzz.tid.go new file mode 100644 index 00000000..1a705277 --- /dev/null +++ b/pgtype/zzz.tid.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (TID) BinaryFormatSupported() bool { + return true +} + +func (TID) TextFormatSupported() bool { + return true +} + +func (TID) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *TID) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src TID) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.time.go b/pgtype/zzz.time.go new file mode 100644 index 00000000..be9a96a7 --- /dev/null +++ b/pgtype/zzz.time.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Time) BinaryFormatSupported() bool { + return true +} + +func (Time) TextFormatSupported() bool { + return true +} + +func (Time) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Time) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Time) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.timestamp.go b/pgtype/zzz.timestamp.go new file mode 100644 index 00000000..ce6135c7 --- /dev/null +++ b/pgtype/zzz.timestamp.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Timestamp) BinaryFormatSupported() bool { + return true +} + +func (Timestamp) TextFormatSupported() bool { + return true +} + +func (Timestamp) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Timestamp) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Timestamp) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.timestamptz.go b/pgtype/zzz.timestamptz.go new file mode 100644 index 00000000..1147b257 --- /dev/null +++ b/pgtype/zzz.timestamptz.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Timestamptz) BinaryFormatSupported() bool { + return true +} + +func (Timestamptz) TextFormatSupported() bool { + return true +} + +func (Timestamptz) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Timestamptz) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Timestamptz) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.uuid.go b/pgtype/zzz.uuid.go new file mode 100644 index 00000000..a0aefaf6 --- /dev/null +++ b/pgtype/zzz.uuid.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (UUID) BinaryFormatSupported() bool { + return true +} + +func (UUID) TextFormatSupported() bool { + return true +} + +func (UUID) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *UUID) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src UUID) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.varbit.go b/pgtype/zzz.varbit.go new file mode 100644 index 00000000..2b090ebf --- /dev/null +++ b/pgtype/zzz.varbit.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Varbit) BinaryFormatSupported() bool { + return true +} + +func (Varbit) TextFormatSupported() bool { + return true +} + +func (Varbit) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *Varbit) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Varbit) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.varchar.go b/pgtype/zzz.varchar.go new file mode 100644 index 00000000..9771d412 --- /dev/null +++ b/pgtype/zzz.varchar.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (Varchar) BinaryFormatSupported() bool { + return true +} + +func (Varchar) TextFormatSupported() bool { + return true +} + +func (Varchar) PreferredFormat() int16 { + return TextFormatCode +} + +func (dst *Varchar) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src Varchar) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +} diff --git a/pgtype/zzz.xid.go b/pgtype/zzz.xid.go new file mode 100644 index 00000000..2754d98e --- /dev/null +++ b/pgtype/zzz.xid.go @@ -0,0 +1,35 @@ +package pgtype + +import "fmt" + +func (XID) BinaryFormatSupported() bool { + return true +} + +func (XID) TextFormatSupported() bool { + return true +} + +func (XID) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (dst *XID) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { + switch format { + case BinaryFormatCode: + return dst.DecodeBinary(ci, src) + case TextFormatCode: + return dst.DecodeText(ci, src) + } + return fmt.Errorf("unknown format code %d", format) +} + +func (src XID) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { + switch format { + case BinaryFormatCode: + return src.EncodeBinary(ci, buf) + case TextFormatCode: + return src.EncodeText(ci, buf) + } + return nil, fmt.Errorf("unknown format code %d", format) +}