From ac2414449c1bc2ddf447d256da3c9fa0ad4f5c36 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 15:33:34 -0600 Subject: [PATCH] Initial proof-of-concept for pgtype Squashed commit of the following: commit c19454582b335ce5bdda6320f7e4e8c76cfeaf44 Author: Jack Christensen Date: Fri Mar 3 15:24:47 2017 -0600 Add AssignTo to pgtype.Timestamptz Also handle infinity for pgtype.Date commit 7329933610b38f4bc15731b1f7c55c520b49e300 Author: Jack Christensen Date: Fri Mar 3 15:12:18 2017 -0600 Implement AssignTo for most pgtypes commit cc3d1e4af896d34ec98c3bf2e982d0367451f21c Author: Jack Christensen Date: Thu Mar 2 21:19:07 2017 -0600 Use pgtype.Int2Array in pgx commit 36da5cc2178d1a31a56dc6e6f128843bd80dea0b Author: Jack Christensen Date: Tue Feb 28 21:45:33 2017 -0600 Add text array transcoding commit 1b0f18d99f38b69f8c2db26388815e67b2b03d59 Author: Jack Christensen Date: Mon Feb 27 19:28:55 2017 -0600 Add ParseUntypedTextArray commit 0f50ce3e833fc38495d333228daf04f5142be676 Author: Jack Christensen Date: Mon Feb 27 18:54:20 2017 -0600 wip commit d934f273627d79997035c282416db922f2fbe87a Author: Jack Christensen Date: Sun Feb 26 17:14:32 2017 -0600 WIP - beginning text format array parsing commit 7276ad33ce7fa9c250745a3ed909998f3dae4a32 Author: Jack Christensen Date: Sat Feb 25 22:50:11 2017 -0600 Beginning binary arrays commit 917faa5a3175d376222423c10aca297a20f96448 Author: Jack Christensen Date: Sat Feb 25 19:36:35 2017 -0600 Fix incomplete tests commit de8c140cfb98b7b047d53c5718ccbf12eaf813a1 Author: Jack Christensen Date: Sat Feb 25 19:32:22 2017 -0600 Add timestamptz null and infinity commit 7d9f954de4e071a1eccac762248079b90dbeb53f Author: Jack Christensen Date: Sat Feb 25 18:19:38 2017 -0600 Add infinity to pgtype.Date commit 7bf783ae20ba05571c2fb9f661183233c95eab41 Author: Jack Christensen Date: Sat Feb 25 17:19:55 2017 -0600 Add Status to pgtype.Date commit 984500455c9b9a4b6221758540d248e6410d93a4 Author: Jack Christensen Date: Sat Feb 25 16:54:01 2017 -0600 Add status to Int4 and Int8 commit 6fe76fcfc2de31552790db3b093480a9d5b2a742 Author: Jack Christensen Date: Sat Feb 25 16:40:27 2017 -0600 Extract testSuccessfulTranscode commit 001647c1da03f796014cf21f41c9a7fd2cfadfde Author: Jack Christensen Date: Sat Feb 25 16:15:51 2017 -0600 Add Status to pgtype.Int2 commit 720451f06d13d9c9fa2a0482e010f24bf4627c2a Author: Jack Christensen Date: Sat Feb 25 15:56:44 2017 -0600 Add status to pgtype.Bool commit 325f700b6edff215a692b10bc5b94cdfe1100769 Author: Jack Christensen Date: Fri Feb 24 17:28:15 2017 -0600 Add date to conversion system commit 4a9343e45d3897f59eab98a0009d2ddbe07e02d7 Author: Jack Christensen Date: Fri Feb 24 16:28:35 2017 -0600 Add bool to oid based encoding commit d984fcafab1476cf84852485b6711f4b2069eb6d Author: Jack Christensen Date: Fri Feb 24 16:15:38 2017 -0600 Add pgtype interfaces commit 0f93bfc2de4023b069b966c0988bf7f0469d1809 Author: Jack Christensen Date: Fri Feb 24 14:48:34 2017 -0600 Begin introduction of Convert commit e5707023cac7c07342b8c910e480d09a1caaf6ee Author: Jack Christensen Date: Fri Feb 24 14:10:56 2017 -0600 Move bool to pgtype commit bb764d2129efe7fb21e841dbb35e6d0dc7586d37 Author: Jack Christensen Date: Fri Feb 24 13:45:05 2017 -0600 Add Int2 test commit 08c49437f455a32f7c3f0a524cd21a895d440301 Author: Jack Christensen Date: Fri Feb 24 13:44:09 2017 -0600 Add Int4 test commit 16722952222fd15c53c8fa84974645504a6d0dc0 Author: Jack Christensen Date: Fri Feb 24 08:56:59 2017 -0600 Add int8 tests commit 83a5447cd2c46b58d0880023cc4e9af0c84988a2 Author: Jack Christensen Date: Wed Feb 22 18:08:05 2017 -0600 wip commit 0ca0ee72068a72b016729b01fccef22474595285 Author: Jack Christensen Date: Mon Feb 20 18:56:52 2017 -0600 wip commit d2c2baf4ea2cd0793d68c7094c425217df952bec Author: Jack Christensen Date: Mon Feb 20 18:46:10 2017 -0600 wip commit f78371da0098356527b193fd496a338da5fe414b Author: Jack Christensen Date: Mon Feb 20 17:43:39 2017 -0600 wip commit 3366699bea62ec0110db05f3cb2998d58ac9ce5d Author: Jack Christensen Date: Mon Feb 20 14:07:47 2017 -0600 wip commit 66b79e940870fd0133ebb10ac1547e1d4d7d0b51 Author: Jack Christensen Date: Mon Feb 20 13:35:37 2017 -0600 Extract pgio commit 8b07d97d1305ed98fd76db6e306a289c0af92d56 Author: Jack Christensen Date: Mon Feb 20 13:20:00 2017 -0600 wip commit 62f1adb3427f4317b708da075dce50c4d4daff7b Author: Jack Christensen Date: Mon Feb 20 12:08:46 2017 -0600 wip commit a712d2546933a5a8433c65eef0ff2ee135077c87 Author: Jack Christensen Date: Mon Feb 20 09:30:52 2017 -0600 wip commit 4faf97cc588126dda160fc360680719572a23105 Author: Jack Christensen Date: Fri Feb 17 22:20:18 2017 -0600 wip --- array.go | 375 ++++++++++++++++++++++++++++++++++++++++++++ array_test.go | 98 ++++++++++++ bool.go | 166 ++++++++++++++++++++ bool_test.go | 43 +++++ convert.go | 239 ++++++++++++++++++++++++++++ date.go | 191 ++++++++++++++++++++++ date_test.go | 51 ++++++ extra-interface.txt | 3 + int2.go | 167 ++++++++++++++++++++ int2_test.go | 55 +++++++ int2array.go | 308 ++++++++++++++++++++++++++++++++++++ int2array_test.go | 87 ++++++++++ int4.go | 158 +++++++++++++++++++ int4_test.go | 55 +++++++ int8.go | 149 ++++++++++++++++++ int8_test.go | 55 +++++++ pgtype.go | 102 ++++++++++++ pgtype_test.go | 108 +++++++++++++ text_element.go | 112 +++++++++++++ timestamptz.go | 203 ++++++++++++++++++++++++ timestamptz_test.go | 60 +++++++ 21 files changed, 2785 insertions(+) create mode 100644 array.go create mode 100644 array_test.go create mode 100644 bool.go create mode 100644 bool_test.go create mode 100644 convert.go create mode 100644 date.go create mode 100644 date_test.go create mode 100644 extra-interface.txt create mode 100644 int2.go create mode 100644 int2_test.go create mode 100644 int2array.go create mode 100644 int2array_test.go create mode 100644 int4.go create mode 100644 int4_test.go create mode 100644 int8.go create mode 100644 int8_test.go create mode 100644 pgtype.go create mode 100644 pgtype_test.go create mode 100644 text_element.go create mode 100644 timestamptz.go create mode 100644 timestamptz_test.go diff --git a/array.go b/array.go new file mode 100644 index 00000000..75d2e440 --- /dev/null +++ b/array.go @@ -0,0 +1,375 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + "strconv" + "unicode" + + "github.com/jackc/pgx/pgio" +) + +// Information on the internals of PostgreSQL arrays can be found in +// src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of +// particular interest is the array_send function. + +type ArrayHeader struct { + ContainsNull bool + ElementOID int32 + Dimensions []ArrayDimension +} + +type ArrayDimension struct { + Length int32 + LowerBound int32 +} + +func (ah *ArrayHeader) DecodeBinary(r io.Reader) error { + numDims, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if numDims > 0 { + ah.Dimensions = make([]ArrayDimension, numDims) + } + + containsNull, err := pgio.ReadInt32(r) + if err != nil { + return err + } + ah.ContainsNull = containsNull == 1 + + ah.ElementOID, err = pgio.ReadInt32(r) + if err != nil { + return err + } + + for i := range ah.Dimensions { + ah.Dimensions[i].Length, err = pgio.ReadInt32(r) + if err != nil { + return err + } + + ah.Dimensions[i].LowerBound, err = pgio.ReadInt32(r) + if err != nil { + return err + } + } + + return nil +} + +func (ah *ArrayHeader) EncodeBinary(w io.Writer) error { + _, err := pgio.WriteInt32(w, int32(len(ah.Dimensions))) + if err != nil { + return err + } + + var containsNull int32 + if ah.ContainsNull { + containsNull = 1 + } + _, err = pgio.WriteInt32(w, containsNull) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, ah.ElementOID) + if err != nil { + return err + } + + for i := range ah.Dimensions { + _, err = pgio.WriteInt32(w, ah.Dimensions[i].Length) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, ah.Dimensions[i].LowerBound) + if err != nil { + return err + } + } + + return nil +} + +type UntypedTextArray struct { + Elements []string + Dimensions []ArrayDimension +} + +func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { + uta := &UntypedTextArray{} + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + var explicitDimensions []ArrayDimension + + // Array has explicit dimensions + if r == '[' { + buf.UnreadRune() + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r == '=' { + break + } else if r != '[' { + return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r) + } + + lower, err := arrayParseInteger(buf) + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != ':' { + return nil, fmt.Errorf("invalid array, expected ':' got %v", r) + } + + upper, err := arrayParseInteger(buf) + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != ']' { + return nil, fmt.Errorf("invalid array, expected ']' got %v", r) + } + + explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1}) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + } + + if r != '{' { + return nil, fmt.Errorf("invalid array, expected '{': %v", err) + } + + implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}} + + // Consume all initial opening brackets. This provides number of dimensions. + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r == '{' { + implicitDimensions[len(implicitDimensions)-1].Length = 1 + implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1}) + } else { + buf.UnreadRune() + break + } + } + currentDim := len(implicitDimensions) - 1 + counterDim := currentDim + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + switch r { + case '{': + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + currentDim++ + case ',': + case '}': + currentDim-- + if currentDim < counterDim { + counterDim = currentDim + } + default: + buf.UnreadRune() + value, err := arrayParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid array value: %v", err) + } + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + uta.Elements = append(uta.Elements, value) + } + + if currentDim < 0 { + break + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + if len(uta.Elements) == 0 { + uta.Dimensions = nil + } else if len(explicitDimensions) > 0 { + uta.Dimensions = explicitDimensions + } else { + uta.Dimensions = implicitDimensions + } + + return uta, nil +} + +func skipWhitespace(buf *bytes.Buffer) { + var r rune + var err error + for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() { + } + + if err != io.EOF { + buf.UnreadRune() + } +} + +func arrayParseValue(buf *bytes.Buffer) (string, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + if r == '"' { + return arrayParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case ',', '}': + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + +func arrayParseQuotedValue(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + buf.UnreadRune() + return s.String(), nil + } + s.WriteRune(r) + } +} + +func arrayParseInteger(buf *bytes.Buffer) (int32, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return 0, err + } + + if '0' <= r && r <= '9' { + s.WriteRune(r) + } else { + buf.UnreadRune() + n, err := strconv.ParseInt(s.String(), 10, 32) + if err != nil { + return 0, err + } + return int32(n), nil + } + } +} + +func EncodeTextArrayDimensions(w io.Writer, dimensions []ArrayDimension) error { + var customDimensions bool + for _, dim := range dimensions { + if dim.LowerBound != 1 { + customDimensions = true + } + } + + if !customDimensions { + return nil + } + + for _, dim := range dimensions { + err := pgio.WriteByte(w, '[') + if err != nil { + return err + } + + _, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound), 10)) + if err != nil { + return err + } + + err = pgio.WriteByte(w, ':') + if err != nil { + return err + } + + _, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)) + if err != nil { + return err + } + + err = pgio.WriteByte(w, ']') + if err != nil { + return err + } + } + + return pgio.WriteByte(w, '=') +} diff --git a/array_test.go b/array_test.go new file mode 100644 index 00000000..5e5f00e7 --- /dev/null +++ b/array_test.go @@ -0,0 +1,98 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestParseUntypedTextArray(t *testing.T) { + tests := []struct { + source string + result pgtype.UntypedTextArray + }{ + { + source: "{}", + result: pgtype.UntypedTextArray{ + Elements: nil, + Dimensions: nil, + }, + }, + { + source: "{1}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{a,b}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b"}, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + }, + }, + { + source: `{"NULL"}`, + result: pgtype.UntypedTextArray{ + Elements: []string{"NULL"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: `{"He said, \"Hello.\""}`, + result: pgtype.UntypedTextArray{ + Elements: []string{`He said, "Hello."`}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{{a,b},{c,d},{e,f}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f"}, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + }, + }, + { + source: "{{{a,b},{c,d},{e,f}},{{a,b},{c,d},{e,f}}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f", "a", "b", "c", "d", "e", "f"}, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 1}, + {Length: 3, LowerBound: 1}, + {Length: 2, LowerBound: 1}, + }, + }, + }, + { + source: "[4:4]={1}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 4}}, + }, + }, + { + source: "[4:5][2:3]={{a,b},{c,d}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d"}, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + }, + }, + } + + for i, tt := range tests { + r, err := pgtype.ParseUntypedTextArray(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !reflect.DeepEqual(*r, tt.result) { + t.Errorf("%d: expected %+v to be parsed to %+v, but it was %+v", i, tt.source, tt.result, *r) + } + } +} diff --git a/bool.go b/bool.go new file mode 100644 index 00000000..81c72472 --- /dev/null +++ b/bool.go @@ -0,0 +1,166 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Bool struct { + Bool bool + Status Status +} + +func (b *Bool) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Bool: + *b = value + case bool: + *b = Bool{Bool: value, Status: Present} + case string: + bb, err := strconv.ParseBool(value) + if err != nil { + return err + } + *b = Bool{Bool: bb, Status: Present} + default: + if originalSrc, ok := underlyingBoolType(src); ok { + return b.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bool", value) + } + + return nil +} + +func (b *Bool) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *bool: + if b.Status != Present { + return fmt.Errorf("cannot assign %v to %T", b, dst) + } + *v = b.Bool + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if b.Status == Null { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return b.AssignTo(el.Interface()) + case reflect.Bool: + if b.Status != Present { + return fmt.Errorf("cannot assign %v to %T", b, dst) + } + el.SetBool(b.Bool) + return nil + } + } + return fmt.Errorf("cannot put decode %v into %T", b, dst) + } + + return nil +} + +func (b *Bool) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *b = Bool{Status: Null} + return nil + } + + if size != 1 { + return fmt.Errorf("invalid length for bool: %v", size) + } + + byt, err := pgio.ReadByte(r) + if err != nil { + return err + } + + *b = Bool{Bool: byt == 't', Status: Present} + return nil +} + +func (b *Bool) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *b = Bool{Status: Null} + return nil + } + + if size != 1 { + return fmt.Errorf("invalid length for bool: %v", size) + } + + byt, err := pgio.ReadByte(r) + if err != nil { + return err + } + + *b = Bool{Bool: byt == 1, Status: Present} + return nil +} + +func (b Bool) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, b.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 1) + if err != nil { + return nil + } + + var buf []byte + if b.Bool { + buf = []byte{'t'} + } else { + buf = []byte{'f'} + } + + _, err = w.Write(buf) + return err +} + +func (b Bool) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, b.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 1) + if err != nil { + return nil + } + + var buf []byte + if b.Bool { + buf = []byte{1} + } else { + buf = []byte{0} + } + + _, err = w.Write(buf) + return err +} diff --git a/bool_test.go b/bool_test.go new file mode 100644 index 00000000..53df1747 --- /dev/null +++ b/bool_test.go @@ -0,0 +1,43 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestBoolTranscode(t *testing.T) { + testSuccessfulTranscode(t, "bool", []interface{}{ + pgtype.Bool{Bool: false, Status: pgtype.Present}, + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: false, Status: pgtype.Null}, + }) +} + +func TestBoolConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Bool + }{ + {source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: false, result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "f", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Bool + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/convert.go b/convert.go new file mode 100644 index 00000000..3f3d9e5f --- /dev/null +++ b/convert.go @@ -0,0 +1,239 @@ +package pgtype + +import ( + "fmt" + "math" + "reflect" + "time" +) + +const maxUint = ^uint(0) +const maxInt = int(maxUint >> 1) +const minInt = -maxInt - 1 + +// underlyingIntType gets the underlying type that can be converted to Int2, Int4, or Int8 +func underlyingIntType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Int: + convVal := int(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int8: + convVal := int8(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int16: + convVal := int16(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int32: + convVal := int32(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int64: + convVal := int64(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint: + convVal := uint(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint8: + convVal := uint8(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint16: + convVal := uint16(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint32: + convVal := uint32(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint64: + convVal := uint64(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingBoolType gets the underlying type that can be converted to Bool +func underlyingBoolType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Bool: + convVal := refVal.Bool() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingTimeType gets the underlying type that can be converted to time.Time +func underlyingTimeType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return time.Time{}, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + timeType := reflect.TypeOf(time.Time{}) + if refVal.Type().ConvertibleTo(timeType) { + return refVal.Convert(timeType).Interface(), true + } + + return time.Time{}, false +} + +// underlyingSliceType gets the underlying slice type +func underlyingSliceType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + baseSliceType := reflect.SliceOf(refVal.Type().Elem()) + if refVal.Type().ConvertibleTo(baseSliceType) { + convVal := refVal.Convert(baseSliceType) + return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() + } + } + + return nil, false +} + +func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { + if srcStatus == Present { + switch v := dst.(type) { + case *int: + if srcVal < int64(minInt) { + return fmt.Errorf("%d is less than minimum value for int", srcVal) + } else if srcVal > int64(maxInt) { + return fmt.Errorf("%d is greater than maximum value for int", srcVal) + } + *v = int(srcVal) + case *int8: + if srcVal < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", srcVal) + } else if srcVal > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", srcVal) + } + *v = int8(srcVal) + case *int16: + if srcVal < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", srcVal) + } else if srcVal > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", srcVal) + } + *v = int16(srcVal) + case *int32: + if srcVal < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", srcVal) + } else if srcVal > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", srcVal) + } + *v = int32(srcVal) + case *int64: + if srcVal < math.MinInt64 { + return fmt.Errorf("%d is less than minimum value for int64", srcVal) + } else if srcVal > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for int64", srcVal) + } + *v = int64(srcVal) + case *uint: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint", srcVal) + } else if uint64(srcVal) > uint64(maxUint) { + return fmt.Errorf("%d is greater than maximum value for uint", srcVal) + } + *v = uint(srcVal) + case *uint8: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint8", srcVal) + } else if srcVal > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", srcVal) + } + *v = uint8(srcVal) + case *uint16: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", srcVal) + } + *v = uint16(srcVal) + case *uint32: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", srcVal) + } + *v = uint32(srcVal) + case *uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint64", srcVal) + } + *v = uint64(srcVal) + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return int64AssignTo(srcVal, srcStatus, el.Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if el.OverflowInt(int64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetInt(int64(srcVal)) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for %T", srcVal, dst) + } + if el.OverflowUint(uint64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetUint(uint64(srcVal)) + return nil + } + } + return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + } + return nil + } + + // if dst is a pointer to pointer and srcStatus is not Present, nil it out + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + if el.Kind() == reflect.Ptr { + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) +} diff --git a/date.go b/date.go new file mode 100644 index 00000000..f3e3e4c6 --- /dev/null +++ b/date.go @@ -0,0 +1,191 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + "time" + + "github.com/jackc/pgx/pgio" +) + +type Date struct { + Time time.Time + Status Status + InfinityModifier +} + +const ( + negativeInfinityDayOffset = -2147483648 + infinityDayOffset = 2147483647 +) + +func (d *Date) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Date: + *d = value + case time.Time: + *d = Date{Time: value, Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return d.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Date", value) + } + + return nil +} + +func (d *Date) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *time.Time: + if d.Status != Present || d.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", d, dst) + } + *v = d.Time + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if d.Status == Null { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return d.AssignTo(el.Interface()) + } + } + return fmt.Errorf("cannot decode %v into %T", d, dst) + } + + return nil +} + +func (d *Date) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *d = Date{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + sbuf := string(buf) + switch sbuf { + case "infinity": + *d = Date{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *d = Date{Status: Present, InfinityModifier: -Infinity} + default: + t, err := time.ParseInLocation("2006-01-02", sbuf, time.UTC) + if err != nil { + return err + } + + *d = Date{Time: t, Status: Present} + } + + return nil +} + +func (d *Date) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *d = Date{Status: Null} + return nil + } + + if size != 4 { + return fmt.Errorf("invalid length for date: %v", size) + } + + dayOffset, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + switch dayOffset { + case infinityDayOffset: + *d = Date{Status: Present, InfinityModifier: Infinity} + case negativeInfinityDayOffset: + *d = Date{Status: Present, InfinityModifier: -Infinity} + default: + t := time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.UTC) + *d = Date{Time: t, Status: Present} + } + + return nil +} + +func (d Date) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, d.Status); done { + return err + } + + var s string + + switch d.InfinityModifier { + case None: + s = d.Time.Format("2006-01-02") + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + + _, err = w.Write([]byte(s)) + return err +} + +func (d Date) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, d.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + var daysSinceDateEpoch int32 + switch d.InfinityModifier { + case None: + tUnix := time.Date(d.Time.Year(), d.Time.Month(), d.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() + dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() + + secSinceDateEpoch := tUnix - dateEpoch + daysSinceDateEpoch = int32(secSinceDateEpoch / 86400) + case Infinity: + daysSinceDateEpoch = infinityDayOffset + case NegativeInfinity: + daysSinceDateEpoch = negativeInfinityDayOffset + } + + _, err = pgio.WriteInt32(w, daysSinceDateEpoch) + return err +} diff --git a/date_test.go b/date_test.go new file mode 100644 index 00000000..c3e971d0 --- /dev/null +++ b/date_test.go @@ -0,0 +1,51 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestDateTranscode(t *testing.T) { + testSuccessfulTranscode(t, "date", []interface{}{ + pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Status: pgtype.Null}, + pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }) +} + +func TestDateConvertFrom(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Date + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var d pgtype.Date + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} diff --git a/extra-interface.txt b/extra-interface.txt new file mode 100644 index 00000000..16453823 --- /dev/null +++ b/extra-interface.txt @@ -0,0 +1,3 @@ +Can pass function to get inet data and function to get oid/name mapping as optional interface with io.Reader or io.Writer + +Could be useful for arrays of types without defined OIDs like hstore. diff --git a/int2.go b/int2.go new file mode 100644 index 00000000..2da8a96d --- /dev/null +++ b/int2.go @@ -0,0 +1,167 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Int2 struct { + Int int16 + Status Status +} + +func (i *Int2) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int2: + *i = value + case int8: + *i = Int2{Int: int16(value), Status: Present} + case uint8: + *i = Int2{Int: int16(value), Status: Present} + case int16: + *i = Int2{Int: int16(value), Status: Present} + case uint16: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case int32: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case uint32: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case int64: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case uint64: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case int: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case uint: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 16) + if err != nil { + return err + } + *i = Int2{Int: int16(num), Status: Present} + default: + if originalSrc, ok := underlyingIntType(src); ok { + return i.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int2", value) + } + + return nil +} + +func (i *Int2) AssignTo(dst interface{}) error { + return int64AssignTo(int64(i.Int), i.Status, dst) +} + +func (i *Int2) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int2{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseInt(string(buf), 10, 16) + if err != nil { + return err + } + + *i = Int2{Int: int16(n), Status: Present} + return nil +} + +func (i *Int2) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int2{Status: Null} + return nil + } + + if size != 2 { + return fmt.Errorf("invalid length for int2: %v", size) + } + + n, err := pgio.ReadInt16(r) + if err != nil { + return err + } + + *i = Int2{Int: int16(n), Status: Present} + return nil +} + +func (i Int2) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + s := strconv.FormatInt(int64(i.Int), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (i Int2) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = pgio.WriteInt16(w, i.Int) + return err +} diff --git a/int2_test.go b/int2_test.go new file mode 100644 index 00000000..a8493a16 --- /dev/null +++ b/int2_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt2Transcode(t *testing.T) { + testSuccessfulTranscode(t, "int2", []interface{}{ + pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}, + pgtype.Int2{Int: -1, Status: pgtype.Present}, + pgtype.Int2{Int: 0, Status: pgtype.Present}, + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: math.MaxInt16, Status: pgtype.Present}, + pgtype.Int2{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt2ConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Int2 + }{ + {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int2 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/int2array.go b/int2array.go new file mode 100644 index 00000000..86375516 --- /dev/null +++ b/int2array.go @@ -0,0 +1,308 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Int2Array struct { + Elements []Int2 + Dimensions []ArrayDimension + Status Status +} + +func (a *Int2Array) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int2Array: + *a = value + case []int16: + if value == nil { + *a = Int2Array{Status: Null} + } else if len(value) == 0 { + *a = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *a = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint16: + if value == nil { + *a = Int2Array{Status: Null} + } else if len(value) == 0 { + *a = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *a = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return a.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int2", value) + } + + return nil +} + +func (a *Int2Array) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *[]int16: + if a.Status == Present { + *v = make([]int16, len(a.Elements)) + for i := range a.Elements { + if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + case *[]uint16: + if a.Status == Present { + *v = make([]uint16, len(a.Elements)) + for i := range a.Elements { + if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + default: + return fmt.Errorf("cannot put decode %v into %T", a, dst) + } + + return nil +} + +func (a *Int2Array) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *a = Int2Array{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Int2 + + if len(uta.Elements) > 0 { + elements = make([]Int2, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int2 + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *a = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (a *Int2Array) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *a = Int2Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *a = Int2Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int2, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *a = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (a *Int2Array) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, a.Status); done { + return err + } + + if len(a.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, a.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(a.Dimensions)) + dimElemCounts[len(a.Dimensions)-1] = int(a.Dimensions[len(a.Dimensions)-1].Length) + for i := len(a.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(a.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range a.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (a *Int2Array) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, a.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range a.Elements { + err := a.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if a.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = Int2OID + arrayHeader.Dimensions = a.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/int2array_test.go b/int2array_test.go new file mode 100644 index 00000000..5ea81990 --- /dev/null +++ b/int2array_test.go @@ -0,0 +1,87 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt2ArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "int2[]", []interface{}{ + &pgtype.Int2Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int2Array{Status: pgtype.Null}, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: 2, Status: pgtype.Present}, + pgtype.Int2{Int: 3, Status: pgtype.Present}, + pgtype.Int2{Int: 4, Status: pgtype.Present}, + pgtype.Int2{Status: pgtype.Null}, + pgtype.Int2{Int: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: 2, Status: pgtype.Present}, + pgtype.Int2{Int: 3, Status: pgtype.Present}, + pgtype.Int2{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +// func TestInt2ConvertFrom(t *testing.T) { +// type _int8 int8 + +// successfulTests := []struct { +// source interface{} +// result pgtype.Int2 +// }{ +// {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// } + +// for i, tt := range successfulTests { +// var r pgtype.Int2 +// err := r.ConvertFrom(tt.source) +// if err != nil { +// t.Errorf("%d: %v", i, err) +// } + +// if r != tt.result { +// t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) +// } +// } +// } diff --git a/int4.go b/int4.go new file mode 100644 index 00000000..84c45522 --- /dev/null +++ b/int4.go @@ -0,0 +1,158 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Int4 struct { + Int int32 + Status Status +} + +func (i *Int4) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int4: + *i = value + case int8: + *i = Int4{Int: int32(value), Status: Present} + case uint8: + *i = Int4{Int: int32(value), Status: Present} + case int16: + *i = Int4{Int: int32(value), Status: Present} + case uint16: + *i = Int4{Int: int32(value), Status: Present} + case int32: + *i = Int4{Int: int32(value), Status: Present} + case uint32: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case int64: + if value < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case uint64: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case int: + if value < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case uint: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 32) + if err != nil { + return err + } + *i = Int4{Int: int32(num), Status: Present} + default: + if originalSrc, ok := underlyingIntType(src); ok { + return i.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int8", value) + } + + return nil +} + +func (i *Int4) AssignTo(dst interface{}) error { + return int64AssignTo(int64(i.Int), i.Status, dst) +} + +func (i *Int4) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int4{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseInt(string(buf), 10, 32) + if err != nil { + return err + } + + *i = Int4{Int: int32(n), Status: Present} + return nil +} + +func (i *Int4) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int4{Status: Null} + return nil + } + + if size != 4 { + return fmt.Errorf("invalid length for int4: %v", size) + } + + n, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + *i = Int4{Int: n, Status: Present} + return nil +} + +func (i Int4) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + s := strconv.FormatInt(int64(i.Int), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (i Int4) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, i.Int) + return err +} diff --git a/int4_test.go b/int4_test.go new file mode 100644 index 00000000..04411849 --- /dev/null +++ b/int4_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt4Transcode(t *testing.T) { + testSuccessfulTranscode(t, "int4", []interface{}{ + pgtype.Int4{Int: math.MinInt32, Status: pgtype.Present}, + pgtype.Int4{Int: -1, Status: pgtype.Present}, + pgtype.Int4{Int: 0, Status: pgtype.Present}, + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: math.MaxInt32, Status: pgtype.Present}, + pgtype.Int4{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt4ConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Int4 + }{ + {source: int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int4 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/int8.go b/int8.go new file mode 100644 index 00000000..c0e14e44 --- /dev/null +++ b/int8.go @@ -0,0 +1,149 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Int8 struct { + Int int64 + Status Status +} + +func (i *Int8) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int8: + *i = value + case int8: + *i = Int8{Int: int64(value), Status: Present} + case uint8: + *i = Int8{Int: int64(value), Status: Present} + case int16: + *i = Int8{Int: int64(value), Status: Present} + case uint16: + *i = Int8{Int: int64(value), Status: Present} + case int32: + *i = Int8{Int: int64(value), Status: Present} + case uint32: + *i = Int8{Int: int64(value), Status: Present} + case int64: + *i = Int8{Int: int64(value), Status: Present} + case uint64: + if value > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *i = Int8{Int: int64(value), Status: Present} + case int: + if int64(value) < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + if int64(value) > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *i = Int8{Int: int64(value), Status: Present} + case uint: + if uint64(value) > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *i = Int8{Int: int64(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + *i = Int8{Int: num, Status: Present} + default: + if originalSrc, ok := underlyingIntType(src); ok { + return i.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int8", value) + } + + return nil +} + +func (i *Int8) AssignTo(dst interface{}) error { + return int64AssignTo(int64(i.Int), i.Status, dst) +} + +func (i *Int8) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int8{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseInt(string(buf), 10, 64) + if err != nil { + return err + } + + *i = Int8{Int: n, Status: Present} + return nil +} + +func (i *Int8) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int8{Status: Null} + return nil + } + + if size != 8 { + return fmt.Errorf("invalid length for int8: %v", size) + } + + n, err := pgio.ReadInt64(r) + if err != nil { + return err + } + + *i = Int8{Int: n, Status: Present} + return nil +} + +func (i Int8) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + s := strconv.FormatInt(i.Int, 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (i Int8) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 8) + if err != nil { + return err + } + + _, err = pgio.WriteInt64(w, i.Int) + return err +} diff --git a/int8_test.go b/int8_test.go new file mode 100644 index 00000000..ba246224 --- /dev/null +++ b/int8_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt8Transcode(t *testing.T) { + testSuccessfulTranscode(t, "int8", []interface{}{ + pgtype.Int8{Int: math.MinInt64, Status: pgtype.Present}, + pgtype.Int8{Int: -1, Status: pgtype.Present}, + pgtype.Int8{Int: 0, Status: pgtype.Present}, + pgtype.Int8{Int: 1, Status: pgtype.Present}, + pgtype.Int8{Int: math.MaxInt64, Status: pgtype.Present}, + pgtype.Int8{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt8ConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Int8 + }{ + {source: int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int8 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype.go b/pgtype.go new file mode 100644 index 00000000..f9833363 --- /dev/null +++ b/pgtype.go @@ -0,0 +1,102 @@ +package pgtype + +import ( + "errors" + "io" + + "github.com/jackc/pgx/pgio" +) + +// PostgreSQL oids for common types +const ( + BoolOID = 16 + ByteaOID = 17 + CharOID = 18 + NameOID = 19 + Int8OID = 20 + Int2OID = 21 + Int4OID = 23 + TextOID = 25 + OIDOID = 26 + TidOID = 27 + XidOID = 28 + CidOID = 29 + JSONOID = 114 + CidrOID = 650 + CidrArrayOID = 651 + Float4OID = 700 + Float8OID = 701 + UnknownOID = 705 + InetOID = 869 + BoolArrayOID = 1000 + Int2ArrayOID = 1005 + Int4ArrayOID = 1007 + TextArrayOID = 1009 + ByteaArrayOID = 1001 + VarcharArrayOID = 1015 + Int8ArrayOID = 1016 + Float4ArrayOID = 1021 + Float8ArrayOID = 1022 + AclItemOID = 1033 + AclItemArrayOID = 1034 + InetArrayOID = 1041 + VarcharOID = 1043 + DateOID = 1082 + TimestampOID = 1114 + TimestampArrayOID = 1115 + TimestampTzOID = 1184 + TimestampTzArrayOID = 1185 + RecordOID = 2249 + UUIDOID = 2950 + JSONBOID = 3802 +) + +type Status byte + +const ( + Undefined Status = iota + Null + Present +) + +type InfinityModifier int8 + +const ( + Infinity InfinityModifier = 1 + None InfinityModifier = 0 + NegativeInfinity InfinityModifier = -Infinity +) + +type Value interface { + ConvertFrom(src interface{}) error + AssignTo(dst interface{}) error +} + +type BinaryDecoder interface { + DecodeBinary(r io.Reader) error +} + +type TextDecoder interface { + DecodeText(r io.Reader) error +} + +type BinaryEncoder interface { + EncodeBinary(w io.Writer) error +} + +type TextEncoder interface { + EncodeText(w io.Writer) error +} + +var errUndefined = errors.New("cannot encode status undefined") + +func encodeNotPresent(w io.Writer, status Status) (done bool, err error) { + switch status { + case Undefined: + return true, errUndefined + case Null: + _, err = pgio.WriteInt32(w, -1) + return true, err + } + return false, nil +} diff --git a/pgtype_test.go b/pgtype_test.go new file mode 100644 index 00000000..a1a575f7 --- /dev/null +++ b/pgtype_test.go @@ -0,0 +1,108 @@ +package pgtype_test + +import ( + "fmt" + "io" + "os" + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" +) + +func mustConnectPgx(t testing.TB) *pgx.Conn { + config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) + if err != nil { + t.Fatal(err) + } + + conn, err := pgx.Connect(config) + if err != nil { + t.Fatal(err) + } + + return conn +} + +func mustClose(t testing.TB, conn interface { + Close() error +}) { + err := conn.Close() + if err != nil { + t.Fatal(err) + } +} + +type forceTextEncoder struct { + e pgtype.TextEncoder +} + +func (f forceTextEncoder) EncodeText(w io.Writer) error { + return f.e.EncodeText(w) +} + +type forceBinaryEncoder struct { + e pgtype.BinaryEncoder +} + +func (f forceBinaryEncoder) EncodeBinary(w io.Writer) error { + return f.e.EncodeBinary(w) +} + +func forceEncoder(e interface{}, formatCode int16) interface{} { + switch formatCode { + case pgx.TextFormatCode: + return forceTextEncoder{e: e.(pgtype.TextEncoder)} + case pgx.BinaryFormatCode: + return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} + default: + panic("bad encoder") + } +} + +func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { + testSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := conn.QueryRow("test", forceEncoder(v, fc.formatCode)).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} diff --git a/text_element.go b/text_element.go new file mode 100644 index 00000000..1a585d08 --- /dev/null +++ b/text_element.go @@ -0,0 +1,112 @@ +package pgtype + +import ( + "bytes" + "errors" + "io" + + "github.com/jackc/pgx/pgio" +) + +// TextElementWriter is a wrapper that makes TextEncoders composable into other +// TextEncoders. TextEncoder first writes the length of the subsequent value. +// This is not necessary when the value is part of another value such as an +// array. TextElementWriter requires one int32 to be written first which it +// ignores. No other integer writes are valid. +type TextElementWriter struct { + w io.Writer + lengthHeaderIgnored bool +} + +func NewTextElementWriter(w io.Writer) *TextElementWriter { + return &TextElementWriter{w: w} +} + +func (w *TextElementWriter) WriteUint16(n uint16) (int, error) { + return 0, errors.New("WriteUint16 should never be called on TextElementWriter") +} + +func (w *TextElementWriter) WriteUint32(n uint32) (int, error) { + if !w.lengthHeaderIgnored { + w.lengthHeaderIgnored = true + + if int32(n) == -1 { + return io.WriteString(w.w, "NULL") + } + + return 4, nil + } + + return 0, errors.New("WriteUint32 should only be called once on TextElementWriter") +} + +func (w *TextElementWriter) WriteUint64(n uint64) (int, error) { + if w.lengthHeaderIgnored { + return pgio.WriteUint64(w.w, n) + } + + return 0, errors.New("WriteUint64 should never be called on TextElementWriter") +} + +func (w *TextElementWriter) Write(buf []byte) (int, error) { + if w.lengthHeaderIgnored { + return w.w.Write(buf) + } + + return 0, errors.New("int32 must be written first") +} + +func (w *TextElementWriter) Reset() { + w.lengthHeaderIgnored = false +} + +// TextElementReader is a wrapper that makes TextDecoders composable into other +// TextDecoders. TextEncoders first read the length of the subsequent value. +// This length value is not present when the value is part of another value such +// as an array. TextElementReader provides a substitute length value from the +// length of the string. No other integer reads are valid. Each time DecodeText +// is called with a TextElementReader as the source the TextElementReader must +// first have Reset called with the new element string data. +type TextElementReader struct { + buf *bytes.Buffer + lengthHeaderIgnored bool +} + +func NewTextElementReader(r io.Reader) *TextElementReader { + return &TextElementReader{buf: &bytes.Buffer{}} +} + +func (r *TextElementReader) ReadUint16() (uint16, error) { + return 0, errors.New("ReadUint16 should never be called on TextElementReader") +} + +func (r *TextElementReader) ReadUint32() (uint32, error) { + if !r.lengthHeaderIgnored { + r.lengthHeaderIgnored = true + if r.buf.String() == "NULL" { + n32 := int32(-1) + return uint32(n32), nil + } + return uint32(r.buf.Len()), nil + } + + return 0, errors.New("ReadUint32 should only be called once on TextElementReader") +} + +func (r *TextElementReader) WriteUint64(n uint64) (int, error) { + return 0, errors.New("ReadUint64 should never be called on TextElementReader") +} + +func (r *TextElementReader) Read(buf []byte) (int, error) { + if r.lengthHeaderIgnored { + return r.buf.Read(buf) + } + + return 0, errors.New("int32 must be read first") +} + +func (r *TextElementReader) Reset(s string) { + r.lengthHeaderIgnored = false + r.buf.Reset() + r.buf.WriteString(s) +} diff --git a/timestamptz.go b/timestamptz.go new file mode 100644 index 00000000..cc33b296 --- /dev/null +++ b/timestamptz.go @@ -0,0 +1,203 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + "time" + + "github.com/jackc/pgx/pgio" +) + +const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" +const pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" +const pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" +const microsecFromUnixEpochToY2K = 946684800 * 1000000 + +const ( + negativeInfinityMicrosecondOffset = -9223372036854775808 + infinityMicrosecondOffset = 9223372036854775807 +) + +type Timestamptz struct { + Time time.Time + Status Status + InfinityModifier +} + +func (t *Timestamptz) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Timestamptz: + *t = value + case time.Time: + *t = Timestamptz{Time: value, Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return t.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Timestamptz", value) + } + + return nil +} + +func (t *Timestamptz) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *time.Time: + if t.Status != Present || t.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", t, dst) + } + *v = t.Time + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if t.Status == Null { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return t.AssignTo(el.Interface()) + } + } + return fmt.Errorf("cannot assign %v into %T", t, dst) + } + + return nil +} + +func (t *Timestamptz) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *t = Timestamptz{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + sbuf := string(buf) + switch sbuf { + case "infinity": + *t = Timestamptz{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *t = Timestamptz{Status: Present, InfinityModifier: -Infinity} + default: + var format string + if sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+' { + format = pgTimestamptzSecondFormat + } else if sbuf[len(sbuf)-6] == '-' || sbuf[len(sbuf)-6] == '+' { + format = pgTimestamptzMinuteFormat + } else { + format = pgTimestamptzHourFormat + } + + tim, err := time.Parse(format, sbuf) + if err != nil { + return err + } + + *t = Timestamptz{Time: tim, Status: Present} + } + + return nil +} + +func (t *Timestamptz) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *t = Timestamptz{Status: Null} + return nil + } + + if size != 8 { + return fmt.Errorf("invalid length for timestamptz: %v", size) + } + + microsecSinceY2K, err := pgio.ReadInt64(r) + if err != nil { + return err + } + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + *t = Timestamptz{Status: Present, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + *t = Timestamptz{Status: Present, InfinityModifier: -Infinity} + default: + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) + *t = Timestamptz{Time: tim, Status: Present} + } + + return nil +} + +func (t Timestamptz) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, t.Status); done { + return err + } + + var s string + + switch t.InfinityModifier { + case None: + s = t.Time.UTC().Format(pgTimestamptzSecondFormat) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + + _, err = w.Write([]byte(s)) + return err +} + +func (t Timestamptz) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, t.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 8) + if err != nil { + return err + } + + var microsecSinceY2K int64 + switch t.InfinityModifier { + case None: + microsecSinceUnixEpoch := t.Time.Unix()*1000000 + int64(t.Time.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + _, err = pgio.WriteInt64(w, microsecSinceY2K) + return err +} diff --git a/timestamptz_test.go b/timestamptz_test.go new file mode 100644 index 00000000..795195f8 --- /dev/null +++ b/timestamptz_test.go @@ -0,0 +1,60 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestTimestamptzTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ + pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Status: pgtype.Null}, + pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Timestamptz) + bt := b.(pgtype.Timestamptz) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + }) +} + +func TestTimestamptzConvertFrom(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Timestamptz + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Timestamptz + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +}